"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "3c8c18a0f4ba4cea56f23a4abd8f3faea36a8674"
Unverified Commit 22e218d3 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] Enable maximum try times for socket backend via DGL_DIST_MAX_T… (#3977)

* [Dist] Enable maximum try times for socket backend via DGL_DIST_MAX_TRY_TIMES

* reset env before/after test

* print log for info when trying to connect

* fix

* print log in python instead of cpp
parent 74f01405
...@@ -18,7 +18,7 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \ ...@@ -18,7 +18,7 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'wait_for_senders', 'connect_receiver', 'read_ip_config', 'get_group_id', \ 'wait_for_senders', 'connect_receiver', 'read_ip_config', 'get_group_id', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \ 'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \ 'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \ 'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', 'DistConnectError', \
'get_num_client', 'set_num_client', 'client_barrier', 'copy_data_to_shared_memory'] 'get_num_client', 'set_num_client', 'client_barrier', 'copy_data_to_shared_memory']
REQUEST_CLASS_TO_SERVICE_ID = {} REQUEST_CLASS_TO_SERVICE_ID = {}
...@@ -175,14 +175,19 @@ def connect_receiver(ip_addr, port, recv_id, group_id=-1): ...@@ -175,14 +175,19 @@ def connect_receiver(ip_addr, port, recv_id, group_id=-1):
raise DGLError("Invalid target id: {}".format(target_id)) raise DGLError("Invalid target id: {}".format(target_id))
return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id)) return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))
def connect_receiver_finalize(): def connect_receiver_finalize(max_try_times):
"""Finalize the action to connect to receivers. Make sure that either all connections are """Finalize the action to connect to receivers. Make sure that either all connections are
successfully established or connection fails. successfully established or connection fails.
When "socket" network backend is in use, the function issues actual requests to receiver When "socket" network backend is in use, the function issues actual requests to receiver
sockets to establish connections. sockets to establish connections.
Parameters
----------
max_try_times : int
maximum try times
""" """
_CAPI_DGLRPCConnectReceiverFinalize() return _CAPI_DGLRPCConnectReceiverFinalize(max_try_times)
def set_rank(rank): def set_rank(rank):
"""Set the rank of this process. """Set the rank of this process.
...@@ -1231,4 +1236,21 @@ def get_client(client_id, group_id): ...@@ -1231,4 +1236,21 @@ def get_client(client_id, group_id):
""" """
return _CAPI_DGLRPCGetClient(int(client_id), int(group_id)) return _CAPI_DGLRPCGetClient(int(client_id), int(group_id))
class DistConnectError(DGLError):
"""Exception raised for errors if fail to connect peer.
Attributes
----------
kv_store : KVServer
reference for KVServer
"""
def __init__(self, max_try_times, ip='', port=''):
peer_str = "peer[{}:{}]".format(ip, port) if ip != '' else "peer"
self.message = "Failed to build conncetion with {} after {} retries. " \
"Please check network availability or increase max try " \
"times via 'DGL_DIST_MAX_TRY_TIMES'.".format(
peer_str, max_try_times)
super().__init__(self.message)
_init_api("dgl.distributed.rpc") _init_api("dgl.distributed.rpc")
...@@ -165,12 +165,21 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, ...@@ -165,12 +165,21 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
rpc.create_sender(max_queue_size, net_type) rpc.create_sender(max_queue_size, net_type)
rpc.create_receiver(max_queue_size, net_type) rpc.create_receiver(max_queue_size, net_type)
# Get connected with all server nodes # Get connected with all server nodes
max_try_times = int(os.environ.get('DGL_DIST_MAX_TRY_TIMES', 1024))
for server_id, addr in server_namebook.items(): for server_id, addr in server_namebook.items():
server_ip = addr[1] server_ip = addr[1]
server_port = addr[2] server_port = addr[2]
try_times = 0
while not rpc.connect_receiver(server_ip, server_port, server_id): while not rpc.connect_receiver(server_ip, server_port, server_id):
try_times += 1
if try_times % 200 == 0:
print("Client is trying to connect server receiver: {}:{}".format(
server_ip, server_port))
if try_times >= max_try_times:
raise rpc.DistConnectError(max_try_times, server_ip, server_port)
time.sleep(3) time.sleep(3)
rpc.connect_receiver_finalize() if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times)
# Get local usable IP address and port # Get local usable IP address and port
ip_addr = get_local_usable_addr(server_ip) ip_addr = get_local_usable_addr(server_ip)
client_ip, client_port = ip_addr.split(':') client_ip, client_port = ip_addr.split(':')
......
...@@ -95,14 +95,14 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -95,14 +95,14 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
try_times = 0 try_times = 0
while not rpc.connect_receiver(client_ip, client_port, client_id, group_id): while not rpc.connect_receiver(client_ip, client_port, client_id, group_id):
try_times += 1 try_times += 1
if try_times % 200 == 0:
print("Server~{} is trying to connect client receiver: {}:{}".format(
server_id, client_ip, client_port))
if try_times >= max_try_times: if try_times >= max_try_times:
raise DGLError("Failed to connect to receiver [{}:{}] after {} " raise rpc.DistConnectError(max_try_times, client_ip, client_port)
"retries. Please check availability of this target "
"receiver or change the max retry times via "
"'DGL_DIST_MAX_TRY_TIMES'.".format(
client_ip, client_port, max_try_times))
time.sleep(1) time.sleep(1)
rpc.connect_receiver_finalize() if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times)
if rpc.get_rank() == 0: # server_0 send all the IDs if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items(): for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id) register_res = rpc.ClientRegisterResponse(client_id)
......
...@@ -262,7 +262,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect") ...@@ -262,7 +262,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
if (sender->ConnectReceiverFinalize() == false) { const int max_try_times = 1024;
if (sender->ConnectReceiverFinalize(max_try_times) == false) {
LOG(FATAL) << "Sender connection failed."; LOG(FATAL) << "Sender connection failed.";
} }
}); });
......
...@@ -49,7 +49,9 @@ struct RPCSender : RPCBase { ...@@ -49,7 +49,9 @@ struct RPCSender : RPCBase {
* *
* The function is *not* thread-safe; only one thread can invoke this API. * The function is *not* thread-safe; only one thread can invoke this API.
*/ */
virtual bool ConnectReceiverFinalize() { return true; } virtual bool ConnectReceiverFinalize(const int max_try_times) {
return true;
}
/*! /*!
* \brief Send RPCMessage to specified Receiver. * \brief Send RPCMessage to specified Receiver.
......
...@@ -55,7 +55,7 @@ bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) { ...@@ -55,7 +55,7 @@ bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
return true; return true;
} }
bool SocketSender::ConnectReceiverFinalize() { bool SocketSender::ConnectReceiverFinalize(const int max_try_times) {
// Create N sockets for Receiver // Create N sockets for Receiver
int receiver_count = static_cast<int>(receiver_addrs_.size()); int receiver_count = static_cast<int>(receiver_addrs_.size());
if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) { if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
...@@ -71,16 +71,16 @@ bool SocketSender::ConnectReceiverFinalize() { ...@@ -71,16 +71,16 @@ bool SocketSender::ConnectReceiverFinalize() {
int try_count = 0; int try_count = 0;
const char* ip = r.second.ip.c_str(); const char* ip = r.second.ip.c_str();
int port = r.second.port; int port = r.second.port;
while (bo == false && try_count < kMaxTryCount) { while (bo == false && try_count < max_try_times) {
if (client_socket->Connect(ip, port)) { if (client_socket->Connect(ip, port)) {
bo = true; bo = true;
} else { } else {
if (try_count % 200 == 0 && try_count != 0) { if (try_count % 200 == 0 && try_count != 0) {
// every 1000 seconds show this message // every 600 seconds show this message
LOG(INFO) << "Try to connect to: " << ip << ":" << port; LOG(INFO) << "Trying to connect receiver: " << ip << ":" << port;
} }
try_count++; try_count++;
std::this_thread::sleep_for(std::chrono::seconds(5)); std::this_thread::sleep_for(std::chrono::seconds(3));
} }
} }
if (bo == false) { if (bo == false) {
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
namespace dgl { namespace dgl {
namespace network { namespace network {
static constexpr int kMaxTryCount = 1024; // maximal connection: 1024
static constexpr int kTimeOut = 10 * 60; // 10 minutes (in seconds) for socket timeout static constexpr int kTimeOut = 10 * 60; // 10 minutes (in seconds) for socket timeout
static constexpr int kMaxConnection = 1024; // maximal connection: 1024 static constexpr int kMaxConnection = 1024; // maximal connection: 1024
...@@ -70,7 +69,7 @@ class SocketSender : public Sender { ...@@ -70,7 +69,7 @@ class SocketSender : public Sender {
* *
* The function is *not* thread-safe; only one thread can invoke this API. * The function is *not* thread-safe; only one thread can invoke this API.
*/ */
bool ConnectReceiverFinalize() override; bool ConnectReceiverFinalize(const int max_try_times) override;
/*! /*!
* \brief Send RPCMessage to specified Receiver. * \brief Send RPCMessage to specified Receiver.
......
...@@ -183,7 +183,8 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver") ...@@ -183,7 +183,8 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize") DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
RPCContext::getInstance()->sender->ConnectReceiverFinalize(); const int max_try_times = args[0];
*rv = RPCContext::getInstance()->sender->ConnectReceiverFinalize(max_try_times);
}); });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank") DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
......
...@@ -81,15 +81,21 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) { ...@@ -81,15 +81,21 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
void TPSender::Finalize() { void TPSender::Finalize() {
for (auto &&p : pipes_) { for (auto &&p : pipes_) {
p.second->close(); if (p.second) {
p.second->close();
}
} }
pipes_.clear(); pipes_.clear();
} }
void TPReceiver::Finalize() { void TPReceiver::Finalize() {
listener_->close(); if (listener_) {
listener_->close();
}
for (auto &&p : pipes_) { for (auto &&p : pipes_) {
p.second->close(); if (p.second) {
p.second->close();
}
} }
pipes_.clear(); pipes_.clear();
} }
......
...@@ -26,6 +26,7 @@ using dgl::network::DefaultMessageDeleter; ...@@ -26,6 +26,7 @@ using dgl::network::DefaultMessageDeleter;
const int64_t kQueueSize = 500 * 1024; const int64_t kQueueSize = 500 * 1024;
const int kThreadNum = 2; const int kThreadNum = 2;
const int kMaxTryTimes = 1024;
#ifndef WIN32 #ifndef WIN32
...@@ -66,7 +67,7 @@ void start_client() { ...@@ -66,7 +67,7 @@ void start_client() {
for (int i = 0; i < kNumReceiver; ++i) { for (int i = 0; i < kNumReceiver; ++i) {
sender.ConnectReceiver(ip_addr[i], i); sender.ConnectReceiver(ip_addr[i], i);
} }
sender.ConnectReceiverFinalize(); sender.ConnectReceiverFinalize(kMaxTryTimes);
for (int i = 0; i < kNumMessage; ++i) { for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumReceiver; ++n) { for (int n = 0; n < kNumReceiver; ++n) {
char* str_data = new char[9]; char* str_data = new char[9];
...@@ -171,7 +172,7 @@ static void start_client() { ...@@ -171,7 +172,7 @@ static void start_client() {
t.close(); t.close();
SocketSender sender(kQueueSize, kThreadNum); SocketSender sender(kQueueSize, kThreadNum);
sender.ConnectReceiver(ip_addr.c_str(), 0); sender.ConnectReceiver(ip_addr.c_str(), 0);
sender.ConnectReceiverFinalize(); sender.ConnectReceiverFinalize(kMaxTryTimes);
char* str_data = new char[9]; char* str_data = new char[9];
memcpy(str_data, "123456789", 9); memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9}; Message msg = {str_data, 9};
......
...@@ -278,6 +278,35 @@ def test_multi_client_groups(): ...@@ -278,6 +278,35 @@ def test_multi_client_groups():
for p in pserver_list: for p in pserver_list:
p.join() p.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
def test_multi_client_connect(net_type):
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
ip_config = "rpc_ip_config_mul_client.txt"
generate_ip_config(ip_config, 1, 1)
ctx = mp.get_context('spawn')
num_clients = 1
pserver = ctx.Process(target=start_server, args=(num_clients, ip_config, 0, False, 1, net_type))
# small max try times
os.environ['DGL_DIST_MAX_TRY_TIMES'] = '1'
expect_except = False
try:
start_client(ip_config, 0, 1, net_type)
except dgl.distributed.DistConnectError as err:
print("Expected error: {}".format(err))
expect_except = True
assert expect_except
# large max try times
os.environ['DGL_DIST_MAX_TRY_TIMES'] = '1024'
pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type))
pclient.start()
pserver.start()
pclient.join()
pserver.join()
reset_envs()
if __name__ == '__main__': if __name__ == '__main__':
test_serialize() test_serialize()
...@@ -286,3 +315,5 @@ if __name__ == '__main__': ...@@ -286,3 +315,5 @@ if __name__ == '__main__':
test_multi_client('socket') test_multi_client('socket')
test_multi_client('tesnsorpipe') test_multi_client('tesnsorpipe')
test_multi_thread_rpc() test_multi_thread_rpc()
test_multi_client_connect('socket')
test_multi_client_connect('tensorpipe')
...@@ -40,6 +40,7 @@ def generate_ip_config(file_name, num_machines, num_servers): ...@@ -40,6 +40,7 @@ def generate_ip_config(file_name, num_machines, num_servers):
def reset_envs(): def reset_envs():
"""Reset common environment variable which are set in tests. """ """Reset common environment variable which are set in tests. """
for key in ['DGL_ROLE', 'DGL_NUM_SAMPLER', 'DGL_NUM_SERVER', 'DGL_DIST_MODE', 'DGL_NUM_CLIENT']: for key in ['DGL_ROLE', 'DGL_NUM_SAMPLER', 'DGL_NUM_SERVER', \
'DGL_DIST_MODE', 'DGL_NUM_CLIENT', 'DGL_DIST_MAX_TRY_TIMES']:
if key in os.environ: if key in os.environ:
os.environ.pop(key) os.environ.pop(key)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment