"model/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "b078dd157cbca4ba31ce49128b0d8e1e4da99b39"
Unverified commit cac3720b, authored by Rhett Ying and committed by GitHub
Browse files

[Dist] enable time out when fetching msg (#4043)

* [Dist] enable time out when fetching msg

* fix lint error

* minor refinements

* improve minor log

* fix dist test

* fix timeout issue in tensorpipe
parent 2de80dde
...@@ -676,18 +676,17 @@ def recv_request(timeout=0): ...@@ -676,18 +676,17 @@ def recv_request(timeout=0):
req : request req : request
One request received from the target, or None if it times out. One request received from the target, or None if it times out.
client_id : int client_id : int
Client' ID received from the target. Client' ID received from the target, or -1 if it times out.
group_id : int group_id : int
Group' ID received from the target. Group' ID received from the target, or -1 if it times out.
Raises Raises
------ ------
ConnectionError if there is any problem with the connection. ConnectionError if there is any problem with the connection.
""" """
# TODO(chao): handle timeout
msg = recv_rpc_message(timeout) msg = recv_rpc_message(timeout)
if msg is None: if msg is None:
return None return None, -1, -1
set_msg_seq(msg.msg_seq) set_msg_seq(msg.msg_seq)
req_cls, _ = SERVICE_ID_TO_PROPERTY[msg.service_id] req_cls, _ = SERVICE_ID_TO_PROPERTY[msg.service_id]
if req_cls is None: if req_cls is None:
...@@ -721,7 +720,6 @@ def recv_response(timeout=0): ...@@ -721,7 +720,6 @@ def recv_response(timeout=0):
------ ------
ConnectionError if there is any problem with the connection. ConnectionError if there is any problem with the connection.
""" """
# TODO(chao): handle timeout
msg = recv_rpc_message(timeout) msg = recv_rpc_message(timeout)
if msg is None: if msg is None:
return None return None
...@@ -764,7 +762,6 @@ def remote_call(target_and_requests, timeout=0): ...@@ -764,7 +762,6 @@ def remote_call(target_and_requests, timeout=0):
------ ------
ConnectionError if there is any problem with the connection. ConnectionError if there is any problem with the connection.
""" """
# TODO(chao): handle timeout
all_res = [None] * len(target_and_requests) all_res = [None] * len(target_and_requests)
msgseq2pos = {} msgseq2pos = {}
num_res = 0 num_res = 0
...@@ -787,6 +784,9 @@ def remote_call(target_and_requests, timeout=0): ...@@ -787,6 +784,9 @@ def remote_call(target_and_requests, timeout=0):
while num_res != 0: while num_res != 0:
# recv response # recv response
msg = recv_rpc_message(timeout) msg = recv_rpc_message(timeout)
if msg is None:
raise DGLError(
f"Timed out for receiving message within {timeout} milliseconds")
num_res -= 1 num_res -= 1
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id] _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None: if res_cls is None:
...@@ -864,6 +864,9 @@ def recv_responses(msgseq2pos, timeout=0): ...@@ -864,6 +864,9 @@ def recv_responses(msgseq2pos, timeout=0):
while num_res != 0: while num_res != 0:
# recv response # recv response
msg = recv_rpc_message(timeout) msg = recv_rpc_message(timeout)
if msg is None:
raise DGLError(
f"Timed out for receiving message within {timeout} milliseconds")
num_res -= 1 num_res -= 1
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id] _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None: if res_cls is None:
...@@ -904,7 +907,6 @@ def remote_call_to_machine(target_and_requests, timeout=0): ...@@ -904,7 +907,6 @@ def remote_call_to_machine(target_and_requests, timeout=0):
------ ------
ConnectionError if there is any problem with the connection. ConnectionError if there is any problem with the connection.
""" """
# TODO(chao): handle timeout
msgseq2pos = send_requests_to_machine(target_and_requests) msgseq2pos = send_requests_to_machine(target_and_requests)
return recv_responses(msgseq2pos, timeout) return recv_responses(msgseq2pos, timeout)
...@@ -955,8 +957,8 @@ def recv_rpc_message(timeout=0): ...@@ -955,8 +957,8 @@ def recv_rpc_message(timeout=0):
ConnectionError if there is any problem with the connection. ConnectionError if there is any problem with the connection.
""" """
msg = _CAPI_DGLRPCCreateEmptyRPCMessage() msg = _CAPI_DGLRPCCreateEmptyRPCMessage()
_CAPI_DGLRPCRecvRPCMessage(timeout, msg) status = _CAPI_DGLRPCRecvRPCMessage(timeout, msg)
return msg return msg if status == 0 else None
def client_barrier(): def client_barrier():
"""Barrier all client processes""" """Barrier all client processes"""
......
...@@ -108,7 +108,10 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -108,7 +108,10 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
register_res = rpc.ClientRegisterResponse(client_id) register_res = rpc.ClientRegisterResponse(client_id)
rpc.send_response(client_id, register_res, group_id) rpc.send_response(client_id, register_res, group_id)
# receive incomming client requests # receive incomming client requests
req, client_id, group_id = rpc.recv_request() timeout = 60 * 1000 # in milliseconds
req, client_id, group_id = rpc.recv_request(timeout)
if req is None:
continue
if isinstance(req, rpc.ClientRegisterRequest): if isinstance(req, rpc.ClientRegisterRequest):
if group_id not in recv_clients: if group_id not in recv_clients:
recv_clients[group_id] = [] recv_clients[group_id] = []
......
...@@ -77,8 +77,10 @@ struct RPCReceiver : RPCBase { ...@@ -77,8 +77,10 @@ struct RPCReceiver : RPCBase {
/*! /*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue. * \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage * \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/ */
virtual void Recv(RPCMessage *msg) = 0; virtual RPCStatus Recv(RPCMessage *msg, int timeout) = 0;
}; };
} // namespace rpc } // namespace rpc
......
...@@ -98,27 +98,25 @@ class Receiver : public rpc::RPCReceiver { ...@@ -98,27 +98,25 @@ class Receiver : public rpc::RPCReceiver {
* \brief Recv data from Sender * \brief Recv data from Sender
* \param msg pointer of data message * \param msg pointer of data message
* \param send_id which sender current msg comes from * \param send_id which sender current msg comes from
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return Status code * \return Status code
* *
* (1) The Recv() API is blocking, which will not * (1) The Recv() API is thread-safe.
* return until getting data from message queue. * (2) Memory allocated by communicator but will not own it after the function returns.
* (2) The Recv() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/ */
virtual STATUS Recv(Message* msg, int* send_id) = 0; virtual STATUS Recv(Message* msg, int* send_id, int timeout = 0) = 0;
/*! /*!
* \brief Recv data from a specified Sender * \brief Recv data from a specified Sender
* \param msg pointer of data message * \param msg pointer of data message
* \param send_id sender's ID * \param send_id sender's ID
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return Status code * \return Status code
* *
* (1) The RecvFrom() API is blocking, which will not * (1) The RecvFrom() API is thread-safe.
* return until getting data from message queue. * (2) Memory allocated by communicator but will not own it after the function returns.
* (2) The RecvFrom() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/ */
virtual STATUS RecvFrom(Message* msg, int send_id) = 0; virtual STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) = 0;
protected: protected:
/*! /*!
......
...@@ -275,32 +275,48 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking ...@@ -275,32 +275,48 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking
return true; return true;
} }
void SocketReceiver::Recv(rpc::RPCMessage* msg) { rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {
Message rpc_meta_msg; Message rpc_meta_msg;
int send_id; int send_id;
CHECK_EQ(Recv( auto status = Recv(&rpc_meta_msg, &send_id, timeout);
&rpc_meta_msg, &send_id), REMOVE_SUCCESS); if (status == QUEUE_EMPTY) {
DLOG(WARNING) << "Timed out when trying to receive rpc meta data after "
<< timeout << " milliseconds.";
return rpc::kRPCTimeOut;
}
CHECK_EQ(status, REMOVE_SUCCESS);
char* count_ptr = rpc_meta_msg.data+rpc_meta_msg.size-sizeof(int32_t); char* count_ptr = rpc_meta_msg.data+rpc_meta_msg.size-sizeof(int32_t);
int32_t nonempty_ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr)); int32_t nonempty_ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr));
// Recv real ndarray data // Recv real ndarray data
std::vector<void*> buffer_list(nonempty_ndarray_count); std::vector<void*> buffer_list(nonempty_ndarray_count);
for (int i = 0; i < nonempty_ndarray_count; ++i) { for (int i = 0; i < nonempty_ndarray_count; ++i) {
Message ndarray_data_msg; Message ndarray_data_msg;
CHECK_EQ(RecvFrom( status = RecvFrom(&ndarray_data_msg, send_id, timeout);
&ndarray_data_msg, send_id), REMOVE_SUCCESS); if (status == QUEUE_EMPTY) {
// As we cannot handle this timeout for now, let's treat it as fatal
// error.
LOG(FATAL) << "Timed out when trying to receive rpc ndarray data after "
<< timeout << " milliseconds.";
return rpc::kRPCTimeOut;
}
CHECK_EQ(status, REMOVE_SUCCESS);
buffer_list[i] = ndarray_data_msg.data; buffer_list[i] = ndarray_data_msg.data;
} }
StreamWithBuffer zc_read_strm(rpc_meta_msg.data, rpc_meta_msg.size-sizeof(int32_t), buffer_list); StreamWithBuffer zc_read_strm(rpc_meta_msg.data, rpc_meta_msg.size-sizeof(int32_t), buffer_list);
zc_read_strm.Read(msg); zc_read_strm.Read(msg);
rpc_meta_msg.deallocator(&rpc_meta_msg); rpc_meta_msg.deallocator(&rpc_meta_msg);
return rpc::kRPCSuccess;
} }
STATUS SocketReceiver::Recv(Message* msg, int* send_id) { STATUS SocketReceiver::Recv(Message* msg, int* send_id, int timeout) {
// queue_sem_ is a semaphore indicating how many elements in multiple // queue_sem_ is a semaphore indicating how many elements in multiple
// message queues. // message queues.
// When calling queue_sem_.Wait(), this Recv will be suspended until // When calling queue_sem_.Wait(), this Recv will be suspended until
// queue_sem_ > 0, decrease queue_sem_ by 1, then start to fetch a message. // queue_sem_ > 0 or specified timeout expires, decrease queue_sem_ by 1,
queue_sem_.Wait(); // then start to fetch a message.
if (!queue_sem_.TimedWait(timeout)) {
return QUEUE_EMPTY;
}
for (;;) { for (;;) {
for (; mq_iter_ != msg_queue_.end(); ++mq_iter_) { for (; mq_iter_ != msg_queue_.end(); ++mq_iter_) {
STATUS code = mq_iter_->second->Remove(msg, false); STATUS code = mq_iter_->second->Remove(msg, false);
...@@ -314,11 +330,16 @@ STATUS SocketReceiver::Recv(Message* msg, int* send_id) { ...@@ -314,11 +330,16 @@ STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
} }
mq_iter_ = msg_queue_.begin(); mq_iter_ = msg_queue_.begin();
} }
LOG(ERROR)
<< "Failed to remove message from queue due to unexpected queue status.";
return QUEUE_CLOSE;
} }
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) { STATUS SocketReceiver::RecvFrom(Message* msg, int send_id, int timeout) {
// Get message from specified message queue // Get message from specified message queue
queue_sem_.Wait(); if (!queue_sem_.TimedWait(timeout)) {
return QUEUE_EMPTY;
}
STATUS code = msg_queue_[send_id]->Remove(msg); STATUS code = msg_queue_[send_id]->Remove(msg);
return code; return code;
} }
......
...@@ -172,34 +172,34 @@ class SocketReceiver : public Receiver { ...@@ -172,34 +172,34 @@ class SocketReceiver : public Receiver {
/*! /*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue. * \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage * \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/ */
void Recv(rpc::RPCMessage* msg) override; rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override;
/*! /*!
* \brief Recv data from Sender. Actually removing data from msg_queue. * \brief Recv data from Sender. Actually removing data from msg_queue.
* \param msg pointer of data message * \param msg pointer of data message
* \param send_id which sender current msg comes from * \param send_id which sender current msg comes from
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return Status code * \return Status code
* *
* (1) The Recv() API is blocking, which will not * (1) The Recv() API is thread-safe.
* return until getting data from message queue. * (2) Memory allocated by communicator but will not own it after the function returns.
* (2) The Recv() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/ */
STATUS Recv(Message* msg, int* send_id) override; STATUS Recv(Message* msg, int* send_id, int timeout = 0) override;
/*! /*!
* \brief Recv data from a specified Sender. Actually removing data from msg_queue. * \brief Recv data from a specified Sender. Actually removing data from msg_queue.
* \param msg pointer of data message * \param msg pointer of data message
* \param send_id sender's ID * \param send_id sender's ID
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return Status code * \return Status code
* *
* (1) The RecvFrom() API is blocking, which will not * (1) The RecvFrom() API is thread-safe.
* return until getting data from message queue. * (2) Memory allocated by communicator but will not own it after the function returns.
* (2) The RecvFrom() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/ */
STATUS RecvFrom(Message* msg, int send_id) override; STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override;
/*! /*!
* \brief Finalize SocketReceiver * \brief Finalize SocketReceiver
......
...@@ -69,10 +69,22 @@ RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) { ...@@ -69,10 +69,22 @@ RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
} }
/*!
 * \brief Receive one RPC message through the global receiver.
 * \param msg Output message filled in on success.
 * \param timeout The timeout value in milliseconds. If zero, keep waiting
 *        (in retry_timeout slices) until a message arrives.
 * \return The status reported by the receiver; kRPCTimeOut can only be
 *         returned when \a timeout is non-zero.
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  static constexpr int32_t retry_timeout = 5 * 1000;  // milliseconds
  RPCStatus status;
  // A zero timeout is served as repeated bounded waits so that every expired
  // slice gets logged instead of blocking silently.
  const int32_t real_timeout = timeout == 0 ? retry_timeout : timeout;
  do {
    status = RPCContext::getInstance()->receiver->Recv(msg, real_timeout);
    if (status == kRPCTimeOut) {
      // Stream the values directly: caching the text in a function-local
      // static would freeze the message at the first call's timeout values.
      DLOG(WARNING) << "Recv RPCMessage timeout in " << real_timeout << " ms."
                    << (timeout == 0 ? " Retrying ..." : "");
    }
  } while (timeout == 0 && status == kRPCTimeOut);
  return status;
}
void InitGlobalTpContext() { void InitGlobalTpContext() {
...@@ -519,9 +531,12 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") ...@@ -519,9 +531,12 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
} }
}); });
// Recv remote message // Recv remote message
for (int i = 0; i < msg_count; ++i) { int recv_cnt = 0;
while (recv_cnt < msg_count) {
RPCMessage msg; RPCMessage msg;
RecvRPCMessage(&msg, 0); auto status = RecvRPCMessage(&msg, 0);
CHECK_EQ(status, kRPCSuccess);
++recv_cnt;
int part_id = msg.server_id / group_count; int part_id = msg.server_id / group_count;
char* data_char = static_cast<char*>(msg.tensors[0]->data); char* data_char = static_cast<char*>(msg.tensors[0]->data);
dgl_id_t id_size = remote_ids[part_id].size(); dgl_id_t id_size = remote_ids[part_id].size();
......
...@@ -158,12 +158,6 @@ struct RPCContext { ...@@ -158,12 +158,6 @@ struct RPCContext {
} }
}; };
/*! \brief RPC status flag */
enum RPCStatus {
kRPCSuccess = 0,
kRPCTimeOut,
};
/*! /*!
* \brief Send out one RPC message. * \brief Send out one RPC message.
* *
......
...@@ -70,6 +70,12 @@ struct RPCMessage : public runtime::Object { ...@@ -70,6 +70,12 @@ struct RPCMessage : public runtime::Object {
DGL_DEFINE_OBJECT_REF(RPCMessageRef, RPCMessage); DGL_DEFINE_OBJECT_REF(RPCMessageRef, RPCMessage);
/*! \brief RPC status flag */
enum RPCStatus {
kRPCSuccess = 0,
kRPCTimeOut,
};
} // namespace rpc } // namespace rpc
} // namespace dgl } // namespace dgl
......
...@@ -8,9 +8,11 @@ ...@@ -8,9 +8,11 @@
#ifndef DGL_RPC_TENSORPIPE_QUEUE_H_ #ifndef DGL_RPC_TENSORPIPE_QUEUE_H_
#define DGL_RPC_TENSORPIPE_QUEUE_H_ #define DGL_RPC_TENSORPIPE_QUEUE_H_
#include <dmlc/logging.h>
#include <condition_variable> #include <condition_variable>
#include <deque> #include <deque>
#include <mutex> #include <mutex>
#include <chrono>
namespace dgl { namespace dgl {
namespace rpc { namespace rpc {
...@@ -30,15 +32,23 @@ class Queue { ...@@ -30,15 +32,23 @@ class Queue {
cv_.notify_all(); cv_.notify_all();
} }
T pop() { bool pop(T *msg, int timeout) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
while (items_.size() == 0) { if (timeout == 0) {
cv_.wait(lock); DLOG(WARNING) << "Will wait infinitely until message is popped...";
cv_.wait(lock, [this] { return items_.size() > 0; });
} else {
if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout),
[this] { return items_.size() > 0; })) {
DLOG(WARNING) << "Times out for popping message after " << timeout
<< " milliseconds.";
return false;
} }
T t(std::move(items_.front())); }
*msg = std::move(items_.front());
items_.pop_front(); items_.pop_front();
cv_.notify_all(); cv_.notify_all();
return t; return true;
} }
private: private:
......
...@@ -193,7 +193,9 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe, ...@@ -193,7 +193,9 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
}); });
} }
void TPReceiver::Recv(RPCMessage *msg) { *msg = std::move(queue_->pop()); } RPCStatus TPReceiver::Recv(RPCMessage *msg, int timeout) {
return queue_->pop(msg, timeout) ? kRPCSuccess : kRPCTimeOut;
}
} // namespace rpc } // namespace rpc
} // namespace dgl } // namespace dgl
...@@ -135,8 +135,10 @@ class TPReceiver : public RPCReceiver { ...@@ -135,8 +135,10 @@ class TPReceiver : public RPCReceiver {
/*! /*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue. * \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage * \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/ */
void Recv(RPCMessage* msg) override; RPCStatus Recv(RPCMessage* msg, int timeout) override;
/*! /*!
* \brief Finalize SocketReceiver * \brief Finalize SocketReceiver
......
...@@ -7,6 +7,12 @@ ...@@ -7,6 +7,12 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#ifndef _WIN32
#include <errno.h>
#include <time.h>
#include <unistd.h>
#endif
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -23,6 +29,12 @@ void Semaphore::Wait() { ...@@ -23,6 +29,12 @@ void Semaphore::Wait() {
WaitForSingleObject(sem_, INFINITE); WaitForSingleObject(sem_, INFINITE);
} }
bool Semaphore::TimedWait(int timeout) {
  // A zero timeout means "wait indefinitely"; any other value is handed to
  // WaitForSingleObject as a bound in milliseconds.
  const DWORD wait_ms =
      timeout == 0 ? INFINITE : static_cast<DWORD>(timeout);
  const DWORD ret = WaitForSingleObject(sem_, wait_ms);
  if (ret == WAIT_OBJECT_0) {
    return true;
  }
  if (ret == WAIT_TIMEOUT) {
    DLOG(WARNING) << "WaitForSingleObject timed out after " << timeout
                  << " milliseconds.";
  } else {
    LOG(ERROR) << "WaitForSingleObject returns unexpectedly. Error: "
               << GetLastError();
  }
  return false;
}
void Semaphore::Post() { void Semaphore::Post() {
ReleaseSemaphore(sem_, 1, nullptr); ReleaseSemaphore(sem_, 1, nullptr);
} }
...@@ -37,6 +49,42 @@ void Semaphore::Wait() { ...@@ -37,6 +49,42 @@ void Semaphore::Wait() {
sem_wait(&sem_); sem_wait(&sem_);
} }
bool Semaphore::TimedWait(int timeout) {
  // zero timeout means wait infinitely
  if (timeout == 0) {
    DLOG(WARNING) << "Will wait infinitely on semaphore until posted.";
    Wait();
    return true;
  }
  // Build the deadline passed to sem_timedwait(): current CLOCK_REALTIME time
  // plus the requested number of milliseconds.
  timespec ts;
  if (clock_gettime(CLOCK_REALTIME, &ts) != 0) {
    LOG(ERROR) << "Failed to get current time via clock_gettime. Errno: "
               << errno;
    return false;
  }
  ts.tv_sec += timeout / MILLISECONDS_PER_SECOND;
  ts.tv_nsec +=
      (timeout % MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND;
  // Carry overflowing nanoseconds into the seconds field.
  if (ts.tv_nsec >= NANOSECONDS_PER_SECOND) {
    ts.tv_nsec -= NANOSECONDS_PER_SECOND;
    ++ts.tv_sec;
  }
  int ret = 0;
  // Retry when the wait is interrupted by a signal. The assignment is
  // parenthesized so that ret holds sem_timedwait()'s return value rather
  // than the result of the comparison.
  while ((ret = sem_timedwait(&sem_, &ts)) != 0 && errno == EINTR) {
    continue;
  }
  if (ret != 0) {
    if (errno == ETIMEDOUT) {
      DLOG(WARNING) << "sem_timedwait timed out after " << timeout
                    << " milliseconds.";
    } else {
      LOG(ERROR) << "sem_timedwait returns unexpectedly. Errno: " << errno;
    }
    return false;
  }
  return true;
}
void Semaphore::Post() { void Semaphore::Post() {
sem_post(&sem_); sem_post(&sem_);
} }
......
...@@ -29,16 +29,27 @@ class Semaphore { ...@@ -29,16 +29,27 @@ class Semaphore {
* \brief blocking wait, decrease semaphore by 1 * \brief blocking wait, decrease semaphore by 1
*/ */
void Wait(); void Wait();
/*!
* \brief timed wait, decrease semaphore by 1 or returns if times out
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
*/
bool TimedWait(int timeout);
/*! /*!
* \brief increase semaphore by 1 * \brief increase semaphore by 1
*/ */
void Post(); void Post();
private: private:
#ifdef _WIN32 #ifdef _WIN32
HANDLE sem_; HANDLE sem_;
#else #else
sem_t sem_; sem_t sem_;
#endif #endif
enum {
MILLISECONDS_PER_SECOND = 1000,
NANOSECONDS_PER_MILLISECOND = 1000 * 1000,
NANOSECONDS_PER_SECOND = 1000 * 1000 * 1000
};
}; };
} // namespace runtime } // namespace runtime
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include <vector> #include <vector>
#include <fstream> #include <fstream>
#include <streambuf> #include <streambuf>
#include <chrono>
#include <stdlib.h> #include <stdlib.h>
#include <time.h> #include <time.h>
...@@ -62,6 +62,49 @@ TEST(SocketCommunicatorTest, SendAndRecv) { ...@@ -62,6 +62,49 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
} }
} }
// Checks that timed Recv()/RecvFrom() deliver queued messages and report
// QUEUE_EMPTY once nothing is left to receive within the timeout.
TEST(SocketCommunicatorTest, SendAndRecvTimeout) {
  // Set by the server after its timeout checks; the client delays
  // Finalize() until then.
  std::atomic_bool stop{false};
  // start 1 client, connect to 1 server, send 2 messages
  auto client = std::thread([&stop]() {
    SocketSender sender(kQueueSize, kThreadNum);
    sender.ConnectReceiver(ip_addr[0], 0);
    sender.ConnectReceiverFinalize(kMaxTryTimes);
    for (int i = 0; i < 2; ++i) {
      char *str_data = new char[9];
      memcpy(str_data, "123456789", 9);
      Message msg = {str_data, 9};
      msg.deallocator = DefaultMessageDeleter;
      EXPECT_EQ(sender.Send(msg, 0), ADD_SUCCESS);
    }
    // Yield while waiting rather than hot-spinning on the flag: the server
    // spends about two seconds in its timed-out receives.
    while (!stop) {
      std::this_thread::yield();
    }
    sender.Finalize();
  });
  // start 1 server, accept 1 client, receive 2 messages
  auto server = std::thread([&stop]() {
    SocketReceiver receiver(kQueueSize, kThreadNum);
    receiver.Wait(ip_addr[0], 1);
    Message msg;
    int recv_id;
    // receive 1st message
    EXPECT_EQ(receiver.RecvFrom(&msg, 0, 0), REMOVE_SUCCESS);
    EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
    msg.deallocator(&msg);
    // receive 2nd message
    EXPECT_EQ(receiver.Recv(&msg, &recv_id, 0), REMOVE_SUCCESS);
    EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
    msg.deallocator(&msg);
    // timed out: both messages were consumed above
    EXPECT_EQ(receiver.RecvFrom(&msg, 0, 1000), QUEUE_EMPTY);
    EXPECT_EQ(receiver.Recv(&msg, &recv_id, 1000), QUEUE_EMPTY);
    stop = true;
    receiver.Finalize();
  });
  // join
  client.join();
  server.join();
}
void start_client() { void start_client() {
SocketSender sender(kQueueSize, kThreadNum); SocketSender sender(kQueueSize, kThreadNum);
for (int i = 0; i < kNumReceiver; ++i) { for (int i = 0; i < kNumReceiver; ++i) {
......
...@@ -29,10 +29,12 @@ private: ...@@ -29,10 +29,12 @@ private:
} }
for (int n = 0; n < kNumSender * kNumMessage * num_machines_; ++n) { for (int n = 0; n < kNumSender * kNumMessage * num_machines_; ++n) {
dgl::rpc::RPCMessage msg; dgl::rpc::RPCMessage msg;
receiver.Recv(&msg); if (receiver.Recv(&msg, 0) != dgl::rpc::kRPCSuccess) {
LOG(FATAL) << "Failed to receive message on Server~" << id;
}
bool eq = msg.data == std::string("123456789"); bool eq = msg.data == std::string("123456789");
eq = eq && (msg.tensors.size() == kNumTensor); eq = eq && (msg.tensors.size() == kNumTensor);
for (int j = 0; j < kNumTensor; ++j) { for (int j = 0; eq && j < kNumTensor; ++j) {
eq = eq && (msg.tensors[j].ToVector<int>().size() == kSizeTensor); eq = eq && (msg.tensors[j].ToVector<int>().size() == kSizeTensor);
} }
if (!eq) { if (!eq) {
......
...@@ -83,6 +83,43 @@ class HelloRequest(dgl.distributed.Request): ...@@ -83,6 +83,43 @@ class HelloRequest(dgl.distributed.Request):
res = HelloResponse(self.hello_str, self.integer, new_tensor) res = HelloResponse(self.hello_str, self.integer, new_tensor)
return res return res
# Service ID under which TimeoutRequest/TimeoutResponse are registered.
TIMEOUT_SERVICE_ID = 123456789
# Marker payload carried by a TimeoutRequest and checked on both ends.
TIMEOUT_META = 'timeout_test'
class TimeoutResponse(dgl.distributed.Response):
    """Response of the timeout test service; carries only ``meta``."""

    def __init__(self, meta):
        self.meta = meta

    def __getstate__(self):
        # The serialized state is the bare ``meta`` value.
        return self.meta

    def __setstate__(self, state):
        self.meta = state
class TimeoutRequest(dgl.distributed.Request):
    """Request whose processing sleeps for ``timeout`` milliseconds.

    When ``response`` is false, processing returns ``None`` instead of a
    ``TimeoutResponse``.
    """

    def __init__(self, meta, timeout, response=True):
        self.meta = meta
        self.timeout = timeout
        self.response = response

    def __getstate__(self):
        return self.meta, self.timeout, self.response

    def __setstate__(self, state):
        meta, timeout, response = state
        self.meta = meta
        self.timeout = timeout
        self.response = response

    def process_request(self, server_state):
        assert self.meta == TIMEOUT_META
        # ``timeout`` is in milliseconds while time.sleep() takes seconds.
        time.sleep(self.timeout / 1000)
        return TimeoutResponse(self.meta) if self.response else None
def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1, net_type='tensorpipe'): def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1, net_type='tensorpipe'):
print("Sleep 1 seconds to test client re-connect.") print("Sleep 1 seconds to test client re-connect.")
time.sleep(1) time.sleep(1)
...@@ -90,6 +127,8 @@ def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_serv ...@@ -90,6 +127,8 @@ def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_serv
None, local_g=None, partition_book=None, keep_alive=keep_alive) None, local_g=None, partition_book=None, keep_alive=keep_alive)
dgl.distributed.register_service( dgl.distributed.register_service(
HELLO_SERVICE_ID, HelloRequest, HelloResponse) HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.register_service(
TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse)
print("Start server {}".format(server_id)) print("Start server {}".format(server_id))
dgl.distributed.start_server(server_id=server_id, dgl.distributed.start_server(server_id=server_id,
ip_config=ip_config, ip_config=ip_config,
...@@ -134,6 +173,67 @@ def start_client(ip_config, group_id=0, num_servers=1, net_type='tensorpipe'): ...@@ -134,6 +173,67 @@ def start_client(ip_config, group_id=0, num_servers=1, net_type='tensorpipe'):
assert res.integer == INTEGER assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR)) assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
def start_client_timeout(ip_config, group_id=0, num_servers=1, net_type='tensorpipe'):
    """Client-side body of the RPC timeout test.

    Registers the timeout service, connects to the server, and then checks
    that ``recv_response`` returns ``None`` and that ``remote_call`` /
    ``remote_call_to_machine`` raise ``dgl.DGLError`` when the wait is
    shorter than the request's sleep time.
    """
    dgl.distributed.register_service(
        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse)
    dgl.distributed.connect_to_server(
        ip_config=ip_config, num_servers=num_servers, group_id=group_id,
        net_type=net_type)
    timeout = 1 * 1000  # milliseconds
    half_timeout = int(timeout / 2)

    def _check_late_response(send_fn):
        # The request sleeps for `timeout` ms, so a wait of half that must
        # come back empty; the following wait without timeout gets the reply.
        send_fn(0, TimeoutRequest(TIMEOUT_META, timeout))
        early = dgl.distributed.recv_response(timeout=half_timeout)
        assert early is None
        late = dgl.distributed.recv_response()
        assert late.meta == TIMEOUT_META

    def _check_call_raises(call_fn):
        # Three requests that never produce a response: the call has to fail
        # with DGLError once the wait expires.
        silent_req = TimeoutRequest(TIMEOUT_META, timeout, response=False)
        batch = [(0, silent_req) for _ in range(3)]
        try:
            call_fn(batch, timeout=half_timeout)
        except dgl.DGLError:
            raised = True
        else:
            raised = False
        assert raised

    # test send and recv
    _check_late_response(dgl.distributed.send_request)
    # test remote_call
    _check_call_raises(dgl.distributed.remote_call)
    # test send_request_to_machine
    _check_late_response(dgl.distributed.send_request_to_machine)
    # test remote_call_to_machine
    _check_call_raises(dgl.distributed.remote_call_to_machine)
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
def test_rpc_timeout(net_type):
    """Run one server and one timeout-testing client as spawned processes."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context('spawn')
    pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, False, 1, net_type))
    pclient = ctx.Process(target=start_client_timeout, args=(ip_config, 0, 1, net_type))
    pserver.start()
    pclient.start()
    pserver.join()
    pclient.join()
    # The timeout assertions run inside the client process; without checking
    # its exit code a failure there would go unnoticed by this test.
    assert pclient.exitcode == 0
def test_serialize(): def test_serialize():
reset_envs() reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed' os.environ['DGL_DIST_MODE'] = 'distributed'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment