Unverified commit e9fd65e9, authored by Rhett Ying and committed by GitHub
Browse files

[BugFix] send rpc messages blockingly in case of congestion (#3867)


Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 172949d4
...@@ -66,13 +66,17 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) { ...@@ -66,13 +66,17 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
LOG(FATAL) << "Cannot send a empty NDArray."; LOG(FATAL) << "Cannot send a empty NDArray.";
} }
} }
// Let's write blockingly in case of congestion in underlying transports.
auto done = std::make_shared<std::promise<void>>();
pipe->write(tp_msg, pipe->write(tp_msg,
[ndarray_holder, recv_id](const tensorpipe::Error &error) { [ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
if (error) { if (error) {
LOG(FATAL) << "Failed to send message to " << recv_id LOG(FATAL) << "Failed to send message to " << recv_id
<< ". Details: " << error.what(); << ". Details: " << error.what();
} }
done->set_value();
}); });
done->get_future().wait();
} }
void TPSender::Finalize() { void TPSender::Finalize() {
......
...@@ -16,7 +16,7 @@ if os.name != 'nt': ...@@ -16,7 +16,7 @@ if os.name != 'nt':
INTEGER = 2 INTEGER = 2
STR = 'hello world!' STR = 'hello world!'
HELLO_SERVICE_ID = 901231 HELLO_SERVICE_ID = 901231
TENSOR = F.zeros((10, 10), F.int64, F.cpu()) TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())
def foo(x, y): def foo(x, y):
assert x == 123 assert x == 123
...@@ -188,15 +188,16 @@ def test_multi_client(): ...@@ -188,15 +188,16 @@ def test_multi_client():
os.environ['DGL_DIST_MODE'] = 'distributed' os.environ['DGL_DIST_MODE'] = 'distributed'
generate_ip_config("rpc_ip_config_mul_client.txt", 1, 1) generate_ip_config("rpc_ip_config_mul_client.txt", 1, 1)
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
pserver = ctx.Process(target=start_server, args=(10, "rpc_ip_config_mul_client.txt")) num_clients = 20
pserver = ctx.Process(target=start_server, args=(num_clients, "rpc_ip_config_mul_client.txt"))
pclient_list = [] pclient_list = []
for i in range(10): for i in range(num_clients):
pclient = ctx.Process(target=start_client, args=("rpc_ip_config_mul_client.txt",)) pclient = ctx.Process(target=start_client, args=("rpc_ip_config_mul_client.txt",))
pclient_list.append(pclient) pclient_list.append(pclient)
pserver.start() pserver.start()
for i in range(10): for i in range(num_clients):
pclient_list[i].start() pclient_list[i].start()
for i in range(10): for i in range(num_clients):
pclient_list[i].join() pclient_list[i].join()
pserver.join() pserver.join()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment