"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b1e83b6e08542dda7586da30625a69a3c0e5e18d"
Unverified Commit 37be02a4 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Feature] enable socket net_type for rpc (#3951)

* [Feature] enable socket net_type for rpc

* fix lint

* fix lint

* fix build issue on windows

* fix test failure on windows

* fix test failure

* fix cpp unit test failure

* net_type blocking max_try_times

* fix other comments

* fix lint

* fix comment

* fix lint

* fix cpp
parent c3baf433
......@@ -259,7 +259,7 @@ def run(args, device, data):
def main(args):
print(socket.gethostname(), 'Initializing DGL dist')
dgl.distributed.initialize(args.ip_config)
dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
if not args.standalone:
print(socket.gethostname(), 'Initializing DGL process group')
th.distributed.init_process_group(backend=args.backend)
......@@ -325,6 +325,8 @@ if __name__ == '__main__':
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
parser.add_argument('--pad-data', default=False, action='store_true',
help='Pad train nid to the same length across machine, to ensure num of batches to be the same.')
parser.add_argument('--net_type', type=str, default='socket',
help="backend net type, 'socket' or 'tensorpipe'")
args = parser.parse_args()
print(args)
......
......@@ -174,7 +174,7 @@ class CustomPool:
def initialize(ip_config, num_servers=1, num_workers=0,
max_queue_size=MAX_QUEUE_SIZE, net_type='socket',
max_queue_size=MAX_QUEUE_SIZE, net_type='tensorpipe',
num_worker_threads=1):
"""Initialize DGL's distributed module
......@@ -201,9 +201,9 @@ def initialize(ip_config, num_servers=1, num_workers=0,
Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
it will not allocate 20GB memory at once.
net_type : str, optional
Networking type. Currently the only valid option is ``'socket'``.
Networking type. Valid options are: ``'socket'``, ``'tensorpipe'``.
Default: ``'socket'``
Default: ``'tensorpipe'``
num_worker_threads: int
The number of threads in a worker process.
......@@ -235,7 +235,8 @@ def initialize(ip_config, num_servers=1, num_workers=0,
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'),
graph_format=formats,
keep_alive=keep_alive)
keep_alive=keep_alive,
net_type=net_type)
serv.start()
sys.exit()
else:
......
......@@ -311,10 +311,13 @@ class DistGraphServer(KVServer):
The graph formats.
keep_alive : bool
Whether to keep server alive when clients exit
net_type : str
Backend rpc type: ``'socket'`` or ``'tensorpipe'``
'''
def __init__(self, server_id, ip_config, num_servers,
num_clients, part_config, disable_shared_mem=False,
graph_format=('csc', 'coo'), keep_alive=False):
graph_format=('csc', 'coo'), keep_alive=False,
net_type='tensorpipe'):
super(DistGraphServer, self).__init__(server_id=server_id,
ip_config=ip_config,
num_servers=num_servers,
......@@ -322,6 +325,7 @@ class DistGraphServer(KVServer):
self.ip_config = ip_config
self.num_servers = num_servers
self.keep_alive = keep_alive
self.net_type = net_type
# Load graph partition data.
if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll initialized afterwards.
......@@ -376,7 +380,9 @@ class DistGraphServer(KVServer):
start_server(server_id=self.server_id,
ip_config=self.ip_config,
num_servers=self.num_servers,
num_clients=self.num_clients, server_state=server_state)
num_clients=self.num_clients,
server_state=server_state,
net_type=self.net_type)
class DistGraph:
'''The class for accessing a distributed graph.
......
......@@ -15,7 +15,7 @@ from .. import backend as F
__all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
'receiver_wait', 'connect_receiver', 'read_ip_config', 'get_group_id', \
'wait_for_senders', 'connect_receiver', 'read_ip_config', 'get_group_id', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
......@@ -112,7 +112,7 @@ def create_sender(max_queue_size, net_type):
max_queue_size : int
Maximal size (bytes) of network queue buffer.
net_type : str
Networking type. Current options are: 'socket'.
Networking type. Current options are: 'socket', 'tensorpipe'.
"""
max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
_CAPI_DGLRPCCreateSender(int(max_queue_size), net_type, max_thread_count)
......@@ -125,7 +125,7 @@ def create_receiver(max_queue_size, net_type):
max_queue_size : int
Maximal size (bytes) of network queue buffer.
net_type : str
Networking type. Current options are: 'socket'.
Networking type. Current options are: 'socket', 'tensorpipe'.
"""
max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
_CAPI_DGLRPCCreateReceiver(int(max_queue_size), net_type, max_thread_count)
......@@ -140,7 +140,7 @@ def finalize_receiver():
"""
_CAPI_DGLRPCFinalizeReceiver()
def receiver_wait(ip_addr, port, num_senders, blocking=True):
def wait_for_senders(ip_addr, port, num_senders, blocking=True):
"""Wait all of the senders' connections.
This api will be blocked until all the senders connect to the receiver.
......@@ -156,7 +156,7 @@ def receiver_wait(ip_addr, port, num_senders, blocking=True):
blocking : bool
whether to wait blockingly
"""
_CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders), blocking)
_CAPI_DGLRPCWaitForSenders(ip_addr, int(port), int(num_senders), blocking)
def connect_receiver(ip_addr, port, recv_id, group_id=-1):
"""Connect to target receiver
......@@ -175,6 +175,15 @@ def connect_receiver(ip_addr, port, recv_id, group_id=-1):
raise DGLError("Invalid target id: {}".format(target_id))
return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))
def connect_receiver_finalize():
    """Complete the connection setup to all receivers.

    Guarantees that either every connection previously registered via
    ``connect_receiver`` is successfully established, or the whole
    attempt fails.

    When the ``'socket'`` network backend is in use, this is the point
    at which the actual connection requests are issued to the receiver
    sockets.
    """
    _CAPI_DGLRPCConnectReceiverFinalize()
def set_rank(rank):
"""Set the rank of this process.
......
......@@ -103,7 +103,7 @@ def get_local_usable_addr(probe_addr):
def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
net_type='socket', group_id=0):
net_type='tensorpipe', group_id=0):
"""Connect this client to server.
Parameters
......@@ -117,7 +117,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
it will not allocate 20GB memory at once.
net_type : str
Networking type. Current options are: 'socket'.
Networking type. Current options are: 'socket', 'tensorpipe'.
group_id : int
Indicates which group this client belongs to. Clients that are
booted together in each launch are gathered as a group and should
......@@ -129,7 +129,8 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
"""
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'.' % net_type
assert net_type in ('socket', 'tensorpipe'), \
'net_type (%s) can only be \'socket\' or \'tensorpipe\'.' % net_type
# Register some basic service
rpc.register_service(rpc.CLIENT_REGISTER,
rpc.ClientRegisterRequest,
......@@ -169,16 +170,19 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
server_port = addr[2]
while not rpc.connect_receiver(server_ip, server_port, server_id):
time.sleep(3)
rpc.connect_receiver_finalize()
# Get local usable IP address and port
ip_addr = get_local_usable_addr(server_ip)
client_ip, client_port = ip_addr.split(':')
# wait server connect back
rpc.receiver_wait(client_ip, client_port, num_servers, blocking=False)
print("Client [{}] waits on {}:{}".format(os.getpid(), client_ip, client_port))
# Register client on server
register_req = rpc.ClientRegisterRequest(ip_addr)
for server_id in range(num_servers):
rpc.send_request(server_id, register_req)
# wait server connect back
rpc.wait_for_senders(client_ip, client_port, num_servers,
blocking=net_type == 'socket')
print("Client [{}] waits on {}:{}".format(
os.getpid(), client_ip, client_port))
# recv client ID from server
res = rpc.recv_response()
rpc.set_rank(res.client_id)
......@@ -222,7 +226,7 @@ def shutdown_servers(ip_config, num_servers):
rpc.register_sig_handler()
server_namebook = rpc.read_ip_config(ip_config, num_servers)
num_servers = len(server_namebook)
rpc.create_sender(MAX_QUEUE_SIZE, 'socket')
rpc.create_sender(MAX_QUEUE_SIZE, 'tensorpipe')
# Get connected with all server nodes
for server_id, addr in server_namebook.items():
server_ip = addr[1]
......
"""Functions used by server."""
import time
import os
from ..base import DGLError
from . import rpc
from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
max_queue_size=MAX_QUEUE_SIZE, net_type='tensorpipe'):
"""Start DGL server, which will be shared with all the rpc services.
This is a blocking function -- it returns only when the server shutdown.
......@@ -31,14 +32,17 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
Note that the 20 GB is just an upper-bound because DGL uses zero-copy and
it will not allocate 20GB memory at once.
net_type : str
Networking type. Current options are: 'socket'.
Networking type. Current options are: ``'socket'`` or ``'tensorpipe'``.
"""
assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'' % net_type
assert net_type in ('socket', 'tensorpipe'), \
'net_type (%s) can only be \'socket\' or \'tensorpipe\'' % net_type
if server_state.keep_alive:
assert net_type == 'tensorpipe', \
"net_type can only be 'tensorpipe' if 'keep_alive' is enabled."
print("As configured, this server will keep alive for multiple"
" client groups until force shutdown request is received.")
# Register signal handler.
......@@ -68,8 +72,9 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# Once all the senders connect to server, server will not
# accept new sender's connection
print(
"Server is waiting for connections non-blockingly on [{}:{}]...".format(ip_addr, port))
rpc.receiver_wait(ip_addr, port, num_clients, blocking=False)
"Server is waiting for connections on [{}:{}]...".format(ip_addr, port))
rpc.wait_for_senders(ip_addr, port, num_clients,
blocking=net_type == 'socket')
rpc.set_num_client(num_clients)
recv_clients = {}
while True:
......@@ -83,11 +88,21 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# a new client group is ready
ips.sort()
client_namebook = dict(enumerate(ips))
time.sleep(3) # wait for clients' receivers ready
max_try_times = int(os.environ.get('DGL_DIST_MAX_TRY_TIMES', 120))
for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':')
# TODO[Rhett]: server should not be blocked endlessly.
try_times = 0
while not rpc.connect_receiver(client_ip, client_port, client_id, group_id):
try_times += 1
if try_times >= max_try_times:
raise DGLError("Failed to connect to receiver [{}:{}] after {} "
"retries. Please check availability of this target "
"receiver or change the max retry times via "
"'DGL_DIST_MAX_TRY_TIMES'.".format(
client_ip, client_port, max_try_times))
time.sleep(1)
rpc.connect_receiver_finalize()
if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id)
......
......@@ -250,19 +250,19 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
int recv_id = args[3];
network::Sender* sender = static_cast<network::Sender*>(chandle);
std::string addr;
if (sender->Type() == "socket") {
if (sender->NetType() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
} else {
LOG(FATAL) << "Unknown communicator type: " << sender->Type();
LOG(FATAL) << "Unknown communicator type: " << sender->NetType();
}
sender->AddReceiver(addr.c_str(), recv_id);
sender->ConnectReceiver(addr.c_str(), recv_id);
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle);
if (sender->Connect() == false) {
if (sender->ConnectReceiverFinalize() == false) {
LOG(FATAL) << "Sender connection failed.";
}
});
......@@ -275,10 +275,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
int num_sender = args[3];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
std::string addr;
if (receiver->Type() == "socket") {
if (receiver->NetType() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
} else {
LOG(FATAL) << "Unknown communicator type: " << receiver->Type();
LOG(FATAL) << "Unknown communicator type: " << receiver->NetType();
}
if (receiver->Wait(addr.c_str(), num_sender) == false) {
LOG(FATAL) << "Wait sender socket failed.";
......
/*!
* Copyright (c) 2022 by Contributors
* \file net_type.h
* \brief Base communicator for DGL distributed training.
*/
#ifndef DGL_RPC_NET_TYPE_H_
#define DGL_RPC_NET_TYPE_H_
#include <string>
#include "rpc_msg.h"
namespace dgl {
namespace rpc {
/*!
 * \brief Common interface shared by RPC senders and receivers.
 *
 * Concrete networking backends (e.g. 'socket', 'tensorpipe') implement
 * this interface.
 */
struct RPCBase {
  /*!
   * \brief Virtual destructor so that concrete communicators are destroyed
   *        correctly when deleted through a base-class pointer.
   */
  virtual ~RPCBase() = default;

  /*!
   * \brief Finalize the communicator and release its resources.
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
   */
  virtual void Finalize() = 0;

  /*!
   * \brief Communicator type: 'socket', 'tensorpipe', etc.
   */
  virtual const std::string &NetType() const = 0;
};
/*!
 * \brief Interface for the sending side of the RPC transport.
 */
struct RPCSender : RPCBase {
  /*!
   * \brief Connect to a receiver.
   *
   * When there are multiple receivers to be connected, the application calls
   * `ConnectReceiver` once per receiver and then calls
   * `ConnectReceiverFinalize` to ensure that either all the connections are
   * successfully established or the whole attempt is reported as failed.
   *
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * \param recv_id receiver's ID
   * \return True for success and False for fail
   *
   * The function is *not* thread-safe; only one thread can invoke this API.
   */
  virtual bool ConnectReceiver(const std::string &addr, int recv_id) = 0;

  /*!
   * \brief Finalize the action to connect to receivers. Make sure that either
   *        all connections are successfully established or the connect
   *        attempt fails.
   * \return True for success and False for fail
   *
   * The default implementation is a no-op that reports success; backends that
   * defer the actual connect until this point (e.g. the socket backend)
   * override it.
   *
   * The function is *not* thread-safe; only one thread can invoke this API.
   */
  virtual bool ConnectReceiverFinalize() { return true; }

  /*!
   * \brief Send RPCMessage to the specified receiver.
   * \param msg data message
   * \param recv_id receiver's ID
   */
  virtual void Send(const RPCMessage &msg, int recv_id) = 0;
};
/*!
 * \brief Interface for the receiving side of the RPC transport.
 */
struct RPCReceiver : RPCBase {
  /*!
   * \brief Wait for all the Senders to connect.
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
   * \param num_sender total number of Senders
   * \param blocking whether to wait blockingly; non-blocking mode is only
   *        supported by some backends (the socket backend requires
   *        blocking == true)
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
   */
  virtual bool Wait(const std::string &addr, int num_sender,
                    bool blocking = true) = 0;

  /*!
   * \brief Recv RPCMessage from a Sender; actually removes the message from
   *        the underlying queue.
   * \param msg pointer to an RPCMessage filled in with the received data
   */
  virtual void Recv(RPCMessage *msg) = 0;
};
} // namespace rpc
} // namespace dgl
#endif // DGL_RPC_NET_TYPE_H_
......@@ -10,6 +10,7 @@
#include <string>
#include "../net_type.h"
#include "msg_queue.h"
namespace dgl {
......@@ -23,7 +24,7 @@ namespace network {
* networking libraries such TCP socket and MPI. One Sender can connect to
* multiple receivers and it can send data to specified receiver via receiver's ID.
*/
class Sender {
class Sender : public rpc::RPCSender {
public:
/*!
* \brief Sender constructor
......@@ -40,23 +41,6 @@ class Sender {
virtual ~Sender() {}
/*!
* \brief Add receiver's address and ID to the sender's namebook
* \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
* \param id receiver's ID
*
* AddReceiver() is not thread-safe and only one thread can invoke this API.
*/
virtual void AddReceiver(const char* addr, int id) = 0;
/*!
* \brief Connect with all the Receivers
* \return True for success and False for fail
*
* Connect() is not thread-safe and only one thread can invoke this API.
*/
virtual bool Connect() = 0;
/*!
* \brief Send data to specified Receiver.
* \param msg data message
......@@ -72,18 +56,6 @@ class Sender {
*/
virtual STATUS Send(Message msg, int recv_id) = 0;
/*!
* \brief Finalize Sender
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
virtual void Finalize() = 0;
/*!
* \brief Communicator type: 'socket', 'mpi', etc.
*/
virtual std::string Type() const = 0;
protected:
/*!
* \brief Size of message queue
......@@ -103,7 +75,7 @@ class Sender {
* libraries such as TCP socket and MPI. One Receiver can connect with multiple Senders
* and it can receive data from multiple Senders concurrently.
*/
class Receiver {
class Receiver : public rpc::RPCReceiver {
public:
/*!
* \brief Receiver constructor
......@@ -122,16 +94,6 @@ class Receiver {
virtual ~Receiver() {}
/*!
* \brief Wait for all the Senders to connect
* \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
* \param num_sender total number of Senders
* \return True for success and False for fail
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
virtual bool Wait(const char* addr, int num_sender) = 0;
/*!
* \brief Recv data from Sender
* \param msg pointer of data message
......@@ -158,18 +120,6 @@ class Receiver {
*/
virtual STATUS RecvFrom(Message* msg, int send_id) = 0;
/*!
* \brief Finalize Receiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
virtual void Finalize() = 0;
/*!
* \brief Communicator type: 'socket', 'mpi', etc
*/
virtual std::string Type() const = 0;
protected:
/*!
* \brief Size of message queue
......
......@@ -27,8 +27,7 @@ namespace network {
/////////////////////////////////////// SocketSender ///////////////////////////////////////////
void SocketSender::AddReceiver(const char* addr, int recv_id) {
CHECK_NOTNULL(addr);
bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
if (recv_id < 0) {
LOG(FATAL) << "recv_id cannot be a negative number.";
}
......@@ -36,25 +35,27 @@ void SocketSender::AddReceiver(const char* addr, int recv_id) {
std::vector<std::string> ip_and_port;
SplitStringUsing(addr, "//", &substring);
// Check address format
if (substring[0] != "socket:" || substring.size() != 2) {
if (substring[0] != "tcp:" || substring.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
<< "e.g, 'tcp://127.0.0.1:50051'. ";
}
// Get IP and port
SplitStringUsing(substring[1], ":", &ip_and_port);
if (ip_and_port.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
<< "e.g, 'tcp://127.0.0.1:50051'. ";
}
IPAddr address;
address.ip = ip_and_port[0];
address.port = std::stoi(ip_and_port[1]);
receiver_addrs_[recv_id] = address;
return true;
}
bool SocketSender::Connect() {
bool SocketSender::ConnectReceiverFinalize() {
// Create N sockets for Receiver
int receiver_count = static_cast<int>(receiver_addrs_.size());
if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
......@@ -79,11 +80,7 @@ bool SocketSender::Connect() {
LOG(INFO) << "Try to connect to: " << ip << ":" << port;
}
try_count++;
#ifdef _WIN32
Sleep(5);
#else // !_WIN32
sleep(5);
#endif // _WIN32
std::this_thread::sleep_for(std::chrono::seconds(5));
}
}
if (bo == false) {
......@@ -103,6 +100,34 @@ bool SocketSender::Connect() {
return true;
}
// Serialize an RPCMessage with the zero-copy stream and ship it as one
// metadata message followed by one raw message per non-empty NDArray buffer
// (reassembled on the other side by SocketReceiver::Recv).
void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
  // Serialize the message metadata into a heap-allocated blob; the blob is
  // shared with the deallocator lambda below so it stays alive until the
  // low-level send has consumed it.
  std::shared_ptr<std::string> zerocopy_blob(new std::string());
  StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true);
  zc_write_strm.Write(msg);
  // Append the number of non-empty tensor buffers so the receiver knows how
  // many follow-up messages to expect.
  int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
  zerocopy_blob->append(reinterpret_cast<char*>(&nonempty_ndarray_count),
                        sizeof(int32_t));
  Message rpc_meta_msg;
  rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data());
  rpc_meta_msg.size = zerocopy_blob->size();
  // Capturing the shared_ptr keeps the blob alive until the message is freed.
  rpc_meta_msg.deallocator = [zerocopy_blob](Message*) {};
  CHECK_EQ(Send(rpc_meta_msg, recv_id), ADD_SUCCESS);
  // send real ndarray data, zero-copy: each message points directly at the
  // tensor's buffer.
  for (auto ptr : zc_write_strm.buffer_list()) {
    Message ndarray_data_msg;
    ndarray_data_msg.data = reinterpret_cast<char*>(ptr.data);
    if (ptr.size == 0) {
      LOG(FATAL) << "Cannot send a empty NDArray.";
    }
    ndarray_data_msg.size = ptr.size;
    // Hold a reference to the tensor so its memory outlives the send.
    NDArray tensor = ptr.tensor;
    ndarray_data_msg.deallocator = [tensor](Message*) {};
    CHECK_EQ(Send(ndarray_data_msg, recv_id), ADD_SUCCESS);
  }
}
STATUS SocketSender::Send(Message msg, int recv_id) {
CHECK_NOTNULL(msg.data);
CHECK_GT(msg.size, 0);
......@@ -119,11 +144,7 @@ void SocketSender::Finalize() {
// wait until queue is empty
auto& mq = msg_queue_[i];
while (mq->Empty() == false) {
#ifdef _WIN32
// just loop
#else // !_WIN32
usleep(1000);
#endif // _WIN32
std::this_thread::sleep_for(std::chrono::seconds(1));
}
// All queues have only one producer, which is main thread, so
// the producerID argument here should be zero.
......@@ -185,25 +206,24 @@ void SocketSender::SendLoop(
}
/////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
bool SocketReceiver::Wait(const char* addr, int num_sender) {
CHECK_NOTNULL(addr);
bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
CHECK_GT(num_sender, 0);
CHECK_EQ(blocking, true);
std::vector<std::string> substring;
std::vector<std::string> ip_and_port;
SplitStringUsing(addr, "//", &substring);
// Check address format
if (substring[0] != "socket:" || substring.size() != 2) {
if (substring[0] != "tcp:" || substring.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
<< "e.g, 'tcp://127.0.0.1:50051'. ";
}
// Get IP and port
SplitStringUsing(substring[1], ":", &ip_and_port);
if (ip_and_port.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
<< "e.g, 'tcp://127.0.0.1:50051'. ";
}
std::string ip = ip_and_port[0];
int port = stoi(ip_and_port[1]);
......@@ -255,6 +275,26 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
return true;
}
// Reassemble an RPCMessage that SocketSender::Send split into one metadata
// message plus one raw message per non-empty NDArray buffer.
void SocketReceiver::Recv(rpc::RPCMessage* msg) {
  Message rpc_meta_msg;
  int send_id;
  CHECK_EQ(Recv(&rpc_meta_msg, &send_id), REMOVE_SUCCESS);
  // The sender appended the tensor-buffer count as a trailing int32.
  char* count_ptr = rpc_meta_msg.data + rpc_meta_msg.size - sizeof(int32_t);
  int32_t nonempty_ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr));
  // Recv real ndarray data: pull the payloads from the *same* sender, in the
  // order they were sent.
  std::vector<void*> buffer_list(nonempty_ndarray_count);
  for (int i = 0; i < nonempty_ndarray_count; ++i) {
    Message ndarray_data_msg;
    CHECK_EQ(RecvFrom(&ndarray_data_msg, send_id), REMOVE_SUCCESS);
    buffer_list[i] = ndarray_data_msg.data;
  }
  // Deserialize the metadata (excluding the trailing count) together with the
  // zero-copy tensor buffers into the output message.
  StreamWithBuffer zc_read_strm(rpc_meta_msg.data,
                                rpc_meta_msg.size - sizeof(int32_t),
                                buffer_list);
  zc_read_strm.Read(msg);
  rpc_meta_msg.deallocator(&rpc_meta_msg);
}
STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
// queue_sem_ is a semaphore indicating how many elements in multiple
// message queues.
......@@ -288,11 +328,7 @@ void SocketReceiver::Finalize() {
for (auto& mq : msg_queue_) {
// wait until queue is empty
while (mq.second->Empty() == false) {
#ifdef _WIN32
// just loop
#else // !_WIN32
usleep(1000);
#endif // _WIN32
std::this_thread::sleep_for(std::chrono::seconds(1));
}
mq.second->SignalFinished(mq.first);
}
......
......@@ -49,21 +49,48 @@ class SocketSender : public Sender {
: Sender(queue_size, max_thread_count) {}
/*!
* \brief Add receiver's address and ID to the sender's namebook
* \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
* \param id receiver's ID
* \brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call `ConnectReceiver`
* for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
* successfully established or some of them fail.
*
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID
* \return True for success and False for fail
*
* AddReceiver() is not thread-safe and only one thread can invoke this API.
* The function is *not* thread-safe; only one thread can invoke this API.
*/
void AddReceiver(const char* addr, int recv_id);
bool ConnectReceiver(const std::string& addr, int recv_id) override;
/*!
* \brief Connect with all the Receivers
* \brief Finalize the action to connect to receivers. Make sure that either
* all connections are successfully established or connection fails.
* \return True for success and False for fail
*
* Connect() is not thread-safe and only one thread can invoke this API.
* The function is *not* thread-safe; only one thread can invoke this API.
*/
bool ConnectReceiverFinalize() override;
/*!
* \brief Send RPCMessage to specified Receiver.
* \param msg data message
* \param recv_id receiver's ID
*/
void Send(const rpc::RPCMessage& msg, int recv_id) override;
/*!
* \brief Finalize TPSender
*/
bool Connect();
void Finalize() override;
/*!
* \brief Communicator type: 'socket'
*/
const std::string &NetType() const override {
static const std::string net_type = "socket";
return net_type;
}
/*!
* \brief Send data to specified Receiver. Actually pushing message to message queue.
......@@ -78,19 +105,7 @@ class SocketSender : public Sender {
* (4) Messages sent to the same receiver are guaranteed to be received in the same order.
* There is no guarantee for messages sent to different receivers.
*/
STATUS Send(Message msg, int recv_id);
/*!
* \brief Finalize SocketSender
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
void Finalize();
/*!
* \brief Communicator type: 'socket'
*/
inline std::string Type() const { return std::string("socket"); }
STATUS Send(Message msg, int recv_id) override;
private:
/*!
......@@ -145,13 +160,21 @@ class SocketReceiver : public Receiver {
/*!
* \brief Wait for all the Senders to connect
* \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
* \param num_sender total number of Senders
* \param blocking whether wait blockingly
* \return True for success and False for fail
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
bool Wait(const char* addr, int num_sender);
bool Wait(const std::string &addr, int num_sender,
bool blocking = true) override;
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
*/
void Recv(rpc::RPCMessage* msg) override;
/*!
* \brief Recv data from Sender. Actually removing data from msg_queue.
......@@ -164,7 +187,7 @@ class SocketReceiver : public Receiver {
* (2) The Recv() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/
STATUS Recv(Message* msg, int* send_id);
STATUS Recv(Message* msg, int* send_id) override;
/*!
* \brief Recv data from a specified Sender. Actually removing data from msg_queue.
......@@ -177,19 +200,22 @@ class SocketReceiver : public Receiver {
* (2) The RecvFrom() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/
STATUS RecvFrom(Message* msg, int send_id);
STATUS RecvFrom(Message* msg, int send_id) override;
/*!
* \brief Finalize SocketReceiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
void Finalize();
void Finalize() override;
/*!
* \brief Communicator type: 'socket'
*/
inline std::string Type() const { return std::string("socket"); }
const std::string &NetType() const override {
static const std::string net_type = "socket";
return net_type;
}
private:
struct RecvContext {
......
......@@ -114,18 +114,38 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0];
std::string type = args[1];
InitGlobalTpContext();
RPCContext::getInstance()->sender =
std::make_shared<TPSender>(RPCContext::getInstance()->ctx);
int max_thread_count = args[2];
if (type == "tensorpipe") {
InitGlobalTpContext();
RPCContext::getInstance()->sender.reset(
new TPSender(RPCContext::getInstance()->ctx));
} else if (type == "socket") {
RPCContext::getInstance()->sender.reset(
new network::SocketSender(msg_queue_size, max_thread_count));
} else {
LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
}
LOG(INFO) << "Sender with NetType~"
<< RPCContext::getInstance()->sender->NetType() << " is created.";
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0];
std::string type = args[1];
InitGlobalTpContext();
RPCContext::getInstance()->receiver =
std::make_shared<TPReceiver>(RPCContext::getInstance()->ctx);
int max_thread_count = args[2];
if (type == "tensorpipe") {
InitGlobalTpContext();
RPCContext::getInstance()->receiver.reset(
new TPReceiver(RPCContext::getInstance()->ctx));
} else if (type == "socket") {
RPCContext::getInstance()->receiver.reset(
new network::SocketReceiver(msg_queue_size, max_thread_count));
} else {
LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
}
LOG(INFO) << "Receiver with NetType~"
<< RPCContext::getInstance()->receiver->NetType() << " is created.";
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
......@@ -138,7 +158,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
RPCContext::getInstance()->receiver->Finalize();
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::string ip = args[0];
int port = args[1];
......@@ -161,6 +181,11 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
*rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
RPCContext::getInstance()->sender->ConnectReceiverFinalize();
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t rank = args[0];
......
......@@ -19,7 +19,9 @@
#include <unordered_map>
#include "./rpc_msg.h"
#include "./tensorpipe/tp_communicator.h"
#include "net_type.h"
#include "network/socket_communicator.h"
#include "tensorpipe/tp_communicator.h"
#include "./network/common.h"
#include "./server_state.h"
......@@ -79,12 +81,12 @@ struct RPCContext {
/*!
* \brief Sender communicator.
*/
std::shared_ptr<TPSender> sender;
std::shared_ptr<RPCSender> sender;
/*!
* \brief Receiver communicator.
*/
std::shared_ptr<TPReceiver> receiver;
std::shared_ptr<RPCReceiver> receiver;
/*!
* \brief Tensorpipe global context
......@@ -194,8 +196,4 @@ RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);
} // namespace rpc
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::rpc::RPCMessage, true);
} // namespace dmlc
#endif // DGL_RPC_RPC_H_
......@@ -7,6 +7,8 @@
#define DGL_RPC_RPC_MSG_H_
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/zerocopy_serializer.h>
#include <string>
#include <vector>
......@@ -70,4 +72,9 @@ DGL_DEFINE_OBJECT_REF(RPCMessageRef, RPCMessage);
} // namespace rpc
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::rpc::RPCMessage, true);
} // namespace dmlc
#endif // DGL_RPC_RPC_MSG_H_
......@@ -17,12 +17,11 @@
#include <vector>
#include <atomic>
#include "./queue.h"
#include "../net_type.h"
namespace dgl {
namespace rpc {
class RPCMessage;
typedef Queue<RPCMessage> RPCMessageQueue;
/*!
......@@ -30,7 +29,7 @@ typedef Queue<RPCMessage> RPCMessageQueue;
*
* TPSender is the communicator implemented by tcp socket.
*/
class TPSender {
class TPSender : public RPCSender {
public:
/*!
* \brief Sender constructor
......@@ -47,30 +46,39 @@ class TPSender {
~TPSender() { Finalize(); }
/*!
* \brief Connect to receiver with address and ID
* \brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call `ConnectReceiver`
* for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
* successfully established or some of them fail.
*
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID
* \return True for success and False for fail
*
* ConnectReceiver() is not thread-safe and only one thread can invoke this API.
* The function is *not* thread-safe; only one thread can invoke this API.
*/
bool ConnectReceiver(const std::string& addr, int recv_id);
bool ConnectReceiver(const std::string& addr, int recv_id) override;
/*!
* \brief Send RPCMessage to specified Receiver.
* \param msg data message \param recv_id receiver's ID
* \param msg data message
* \param recv_id receiver's ID
*/
void Send(const RPCMessage& msg, int recv_id);
void Send(const RPCMessage& msg, int recv_id) override;
/*!
* \brief Finalize TPSender
*/
void Finalize();
void Finalize() override;
/*!
* \brief Communicator type: 'tp'
*/
inline std::string Type() const { return std::string("tp"); }
const std::string &NetType() const override {
static const std::string net_type = "tensorpipe";
return net_type;
}
private:
/*!
......@@ -95,7 +103,7 @@ class TPSender {
*
* Tensorpipe Receiver is the communicator implemented by tcp socket.
*/
class TPReceiver {
class TPReceiver : public RPCReceiver {
public:
/*!
* \brief Receiver constructor
......@@ -121,33 +129,29 @@ class TPReceiver {
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
bool Wait(const std::string &addr, int num_sender, bool blocking = true);
bool Wait(const std::string &addr, int num_sender,
bool blocking = true) override;
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
* \param send_id which sender current msg comes from
* \return Status code
*
* (1) The Recv() API is blocking, which will not
* return until getting data from message queue.
* (2) The Recv() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function
* returns.
*/
void Recv(RPCMessage* msg);
void Recv(RPCMessage* msg) override;
/*!
* \brief Finalize SocketReceiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
void Finalize();
void Finalize() override;
/*!
* \brief Communicator type: 'tp' (tensorpipe)
*/
inline std::string Type() const { return std::string("tp"); }
const std::string &NetType() const override {
static const std::string net_type = "tensorpipe";
return net_type;
}
/*!
* \brief Issue a receive request on pipe, and push the result into queue
......
......@@ -34,9 +34,9 @@ const int kNumReceiver = 3;
const int kNumMessage = 10;
const char* ip_addr[] = {
"socket://127.0.0.1:50091",
"socket://127.0.0.1:50092",
"socket://127.0.0.1:50093"
"tcp://127.0.0.1:50091",
"tcp://127.0.0.1:50092",
"tcp://127.0.0.1:50093"
};
static void start_client();
......@@ -64,9 +64,9 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
void start_client() {
SocketSender sender(kQueueSize, kThreadNum);
for (int i = 0; i < kNumReceiver; ++i) {
sender.AddReceiver(ip_addr[i], i);
sender.ConnectReceiver(ip_addr[i], i);
}
sender.Connect();
sender.ConnectReceiverFinalize();
for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumReceiver; ++n) {
char* str_data = new char[9];
......@@ -140,7 +140,7 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
srand((unsigned)time(NULL));
int port = (rand() % (5000-3000+1))+ 3000;
std::string ip_addr = "socket://127.0.0.1:" + std::to_string(port);
std::string ip_addr = "tcp://127.0.0.1:" + std::to_string(port);
std::ofstream out("addr.txt");
out << ip_addr;
out.close();
......@@ -170,8 +170,8 @@ static void start_client() {
std::istreambuf_iterator<char>());
t.close();
SocketSender sender(kQueueSize, kThreadNum);
sender.AddReceiver(ip_addr.c_str(), 0);
sender.Connect();
sender.ConnectReceiver(ip_addr.c_str(), 0);
sender.ConnectReceiverFinalize();
char* str_data = new char[9];
memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9};
......
......@@ -83,7 +83,7 @@ class HelloRequest(dgl.distributed.Request):
res = HelloResponse(self.hello_str, self.integer, new_tensor)
return res
def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1):
def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1, net_type='tensorpipe'):
print("Sleep 1 seconds to test client re-connect.")
time.sleep(1)
server_state = dgl.distributed.ServerState(
......@@ -95,12 +95,13 @@ def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_serv
ip_config=ip_config,
num_servers=num_servers,
num_clients=num_clients,
server_state=server_state)
server_state=server_state,
net_type=net_type)
def start_client(ip_config, group_id=0, num_servers=1):
def start_client(ip_config, group_id=0, num_servers=1, net_type='tensorpipe'):
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.connect_to_server(
ip_config=ip_config, num_servers=num_servers, group_id=group_id)
ip_config=ip_config, num_servers=num_servers, group_id=group_id, net_type=net_type)
req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
# test send and recv
dgl.distributed.send_request(0, req)
......@@ -183,16 +184,18 @@ def test_rpc():
pclient.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client():
@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
def test_multi_client(net_type):
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
generate_ip_config("rpc_ip_config_mul_client.txt", 1, 1)
ip_config = "rpc_ip_config_mul_client.txt"
generate_ip_config(ip_config, 1, 1)
ctx = mp.get_context('spawn')
num_clients = 20
pserver = ctx.Process(target=start_server, args=(num_clients, "rpc_ip_config_mul_client.txt"))
pserver = ctx.Process(target=start_server, args=(num_clients, ip_config, 0, False, 1, net_type))
pclient_list = []
for i in range(num_clients):
pclient = ctx.Process(target=start_client, args=("rpc_ip_config_mul_client.txt",))
pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type))
pclient_list.append(pclient)
pserver.start()
for i in range(num_clients):
......@@ -280,5 +283,6 @@ if __name__ == '__main__':
test_serialize()
test_rpc_msg()
test_rpc()
test_multi_client()
test_multi_client('socket')
test_multi_client('tesnsorpipe')
test_multi_thread_rpc()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment