Unverified commit 22e218d3 authored by Rhett Ying, committed by GitHub

[Dist] Enable maximum try times for socket backend via DGL_DIST_MAX_TRY_TIMES (#3977)

* [Dist] Enable maximum try times for socket backend via DGL_DIST_MAX_TRY_TIMES

* reset env before/after test

* print log for info when trying to connect

* fix

* print log in python instead of cpp
parent 74f01405
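For quick reference, a minimal sketch of how the new knob might be set from a launch script. The environment variable name and its default of 1024 attempts (with the client sleeping roughly 3 seconds between attempts) come from this commit; the ip_config path and the initialize() call are only illustrative.

import os

# Raise the connection retry cap before any DGL distributed setup runs;
# connect_to_server() reads this variable from the environment
# (default: 1024 attempts, with clients sleeping ~3 s between attempts).
os.environ['DGL_DIST_MAX_TRY_TIMES'] = '2048'

import dgl.distributed as dist

# 'ip_config.txt' is a placeholder for the usual cluster configuration file.
dist.initialize('ip_config.txt')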
......@@ -18,7 +18,7 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'wait_for_senders', 'connect_receiver', 'read_ip_config', 'get_group_id', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', 'DistConnectError', \
'get_num_client', 'set_num_client', 'client_barrier', 'copy_data_to_shared_memory']
REQUEST_CLASS_TO_SERVICE_ID = {}
......@@ -175,14 +175,19 @@ def connect_receiver(ip_addr, port, recv_id, group_id=-1):
raise DGLError("Invalid target id: {}".format(target_id))
return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))
def connect_receiver_finalize():
def connect_receiver_finalize(max_try_times):
"""Finalize the action to connect to receivers. Make sure that either all connections are
successfully established or connection fails.
When "socket" network backend is in use, the function issues actual requests to receiver
sockets to establish connections.
Parameters
----------
max_try_times : int
Maximum number of connection attempts before giving up.
"""
_CAPI_DGLRPCConnectReceiverFinalize()
return _CAPI_DGLRPCConnectReceiverFinalize(max_try_times)
def set_rank(rank):
"""Set the rank of this process.
......@@ -1231,4 +1236,21 @@ def get_client(client_id, group_id):
"""
return _CAPI_DGLRPCGetClient(int(client_id), int(group_id))
class DistConnectError(DGLError):
"""Exception raised for errors if fail to connect peer.
Attributes
----------
kv_store : KVServer
reference for KVServer
"""
def __init__(self, max_try_times, ip='', port=''):
peer_str = "peer[{}:{}]".format(ip, port) if ip != '' else "peer"
self.message = "Failed to build conncetion with {} after {} retries. " \
"Please check network availability or increase max try " \
"times via 'DGL_DIST_MAX_TRY_TIMES'.".format(
peer_str, max_try_times)
super().__init__(self.message)
_init_api("dgl.distributed.rpc")
......@@ -165,12 +165,21 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
rpc.create_sender(max_queue_size, net_type)
rpc.create_receiver(max_queue_size, net_type)
# Get connected with all server nodes
max_try_times = int(os.environ.get('DGL_DIST_MAX_TRY_TIMES', 1024))
for server_id, addr in server_namebook.items():
server_ip = addr[1]
server_port = addr[2]
try_times = 0
while not rpc.connect_receiver(server_ip, server_port, server_id):
try_times += 1
if try_times % 200 == 0:
print("Client is trying to connect server receiver: {}:{}".format(
server_ip, server_port))
if try_times >= max_try_times:
raise rpc.DistConnectError(max_try_times, server_ip, server_port)
time.sleep(3)
rpc.connect_receiver_finalize()
if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times)
# Get local usable IP address and port
ip_addr = get_local_usable_addr(server_ip)
client_ip, client_port = ip_addr.split(':')
......
......@@ -95,14 +95,14 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
try_times = 0
while not rpc.connect_receiver(client_ip, client_port, client_id, group_id):
try_times += 1
if try_times % 200 == 0:
print("Server~{} is trying to connect client receiver: {}:{}".format(
server_id, client_ip, client_port))
if try_times >= max_try_times:
raise DGLError("Failed to connect to receiver [{}:{}] after {} "
"retries. Please check availability of this target "
"receiver or change the max retry times via "
"'DGL_DIST_MAX_TRY_TIMES'.".format(
client_ip, client_port, max_try_times))
raise rpc.DistConnectError(max_try_times, client_ip, client_port)
time.sleep(1)
rpc.connect_receiver_finalize()
if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times)
if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id)
......
......@@ -262,7 +262,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle);
if (sender->ConnectReceiverFinalize() == false) {
const int max_try_times = 1024;
if (sender->ConnectReceiverFinalize(max_try_times) == false) {
LOG(FATAL) << "Sender connection failed.";
}
});
......
......@@ -49,7 +49,9 @@ struct RPCSender : RPCBase {
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
virtual bool ConnectReceiverFinalize() { return true; }
virtual bool ConnectReceiverFinalize(const int max_try_times) {
return true;
}
/*!
* \brief Send RPCMessage to specified Receiver.
......
......@@ -55,7 +55,7 @@ bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
return true;
}
bool SocketSender::ConnectReceiverFinalize() {
bool SocketSender::ConnectReceiverFinalize(const int max_try_times) {
// Create N sockets for Receiver
int receiver_count = static_cast<int>(receiver_addrs_.size());
if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
......@@ -71,16 +71,16 @@ bool SocketSender::ConnectReceiverFinalize() {
int try_count = 0;
const char* ip = r.second.ip.c_str();
int port = r.second.port;
while (bo == false && try_count < kMaxTryCount) {
while (bo == false && try_count < max_try_times) {
if (client_socket->Connect(ip, port)) {
bo = true;
} else {
if (try_count % 200 == 0 && try_count != 0) {
// every 1000 seconds show this message
LOG(INFO) << "Try to connect to: " << ip << ":" << port;
// every 600 seconds show this message
LOG(INFO) << "Trying to connect receiver: " << ip << ":" << port;
}
try_count++;
std::this_thread::sleep_for(std::chrono::seconds(5));
std::this_thread::sleep_for(std::chrono::seconds(3));
}
}
if (bo == false) {
......
......@@ -21,7 +21,6 @@
namespace dgl {
namespace network {
static constexpr int kMaxTryCount = 1024; // maximal connection: 1024
static constexpr int kTimeOut = 10 * 60; // 10 minutes (in seconds) for socket timeout
static constexpr int kMaxConnection = 1024; // maximal connection: 1024
......@@ -70,7 +69,7 @@ class SocketSender : public Sender {
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
bool ConnectReceiverFinalize() override;
bool ConnectReceiverFinalize(const int max_try_times) override;
/*!
* \brief Send RPCMessage to specified Receiver.
......
......@@ -183,7 +183,8 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
RPCContext::getInstance()->sender->ConnectReceiverFinalize();
const int max_try_times = args[0];
*rv = RPCContext::getInstance()->sender->ConnectReceiverFinalize(max_try_times);
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
......
......@@ -81,16 +81,22 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
void TPSender::Finalize() {
for (auto &&p : pipes_) {
if (p.second) {
p.second->close();
}
}
pipes_.clear();
}
void TPReceiver::Finalize() {
if (listener_) {
listener_->close();
}
for (auto &&p : pipes_) {
if (p.second) {
p.second->close();
}
}
pipes_.clear();
}
......
......@@ -26,6 +26,7 @@ using dgl::network::DefaultMessageDeleter;
const int64_t kQueueSize = 500 * 1024;
const int kThreadNum = 2;
const int kMaxTryTimes = 1024;
#ifndef WIN32
......@@ -66,7 +67,7 @@ void start_client() {
for (int i = 0; i < kNumReceiver; ++i) {
sender.ConnectReceiver(ip_addr[i], i);
}
sender.ConnectReceiverFinalize();
sender.ConnectReceiverFinalize(kMaxTryTimes);
for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumReceiver; ++n) {
char* str_data = new char[9];
......@@ -171,7 +172,7 @@ static void start_client() {
t.close();
SocketSender sender(kQueueSize, kThreadNum);
sender.ConnectReceiver(ip_addr.c_str(), 0);
sender.ConnectReceiverFinalize();
sender.ConnectReceiverFinalize(kMaxTryTimes);
char* str_data = new char[9];
memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9};
......
......@@ -278,6 +278,35 @@ def test_multi_client_groups():
for p in pserver_list:
p.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
def test_multi_client_connect(net_type):
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
ip_config = "rpc_ip_config_mul_client.txt"
generate_ip_config(ip_config, 1, 1)
ctx = mp.get_context('spawn')
num_clients = 1
pserver = ctx.Process(target=start_server, args=(num_clients, ip_config, 0, False, 1, net_type))
# small max try times
os.environ['DGL_DIST_MAX_TRY_TIMES'] = '1'
expect_except = False
try:
start_client(ip_config, 0, 1, net_type)
except dgl.distributed.DistConnectError as err:
print("Expected error: {}".format(err))
expect_except = True
assert expect_except
# large max try times
os.environ['DGL_DIST_MAX_TRY_TIMES'] = '1024'
pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type))
pclient.start()
pserver.start()
pclient.join()
pserver.join()
reset_envs()
if __name__ == '__main__':
test_serialize()
......@@ -286,3 +315,5 @@ if __name__ == '__main__':
test_multi_client('socket')
test_multi_client('tensorpipe')
test_multi_thread_rpc()
test_multi_client_connect('socket')
test_multi_client_connect('tensorpipe')
......@@ -40,6 +40,7 @@ def generate_ip_config(file_name, num_machines, num_servers):
def reset_envs():
"""Reset common environment variable which are set in tests. """
for key in ['DGL_ROLE', 'DGL_NUM_SAMPLER', 'DGL_NUM_SERVER', 'DGL_DIST_MODE', 'DGL_NUM_CLIENT']:
for key in ['DGL_ROLE', 'DGL_NUM_SAMPLER', 'DGL_NUM_SERVER', \
'DGL_DIST_MODE', 'DGL_NUM_CLIENT', 'DGL_DIST_MAX_TRY_TIMES']:
if key in os.environ:
os.environ.pop(key)