Unverified Commit cac3720b authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] enable time out when fetching msg (#4043)

* [Dist] enable time out when fetching msg

* fix lint error

* minor refinements

* improve minor log

* fix dist test

* fix timeout issue in tensorpipe
parent 2de80dde
......@@ -676,18 +676,17 @@ def recv_request(timeout=0):
req : request
One request received from the target, or None if it times out.
client_id : int
Client' ID received from the target.
Client' ID received from the target, or -1 if it times out.
group_id : int
Group' ID received from the target.
Group' ID received from the target, or -1 if it times out.
Raises
------
ConnectionError if there is any problem with the connection.
"""
# TODO(chao): handle timeout
msg = recv_rpc_message(timeout)
if msg is None:
return None
return None, -1, -1
set_msg_seq(msg.msg_seq)
req_cls, _ = SERVICE_ID_TO_PROPERTY[msg.service_id]
if req_cls is None:
......@@ -721,7 +720,6 @@ def recv_response(timeout=0):
------
ConnectionError if there is any problem with the connection.
"""
# TODO(chao): handle timeout
msg = recv_rpc_message(timeout)
if msg is None:
return None
......@@ -764,7 +762,6 @@ def remote_call(target_and_requests, timeout=0):
------
ConnectionError if there is any problem with the connection.
"""
# TODO(chao): handle timeout
all_res = [None] * len(target_and_requests)
msgseq2pos = {}
num_res = 0
......@@ -787,6 +784,9 @@ def remote_call(target_and_requests, timeout=0):
while num_res != 0:
# recv response
msg = recv_rpc_message(timeout)
if msg is None:
raise DGLError(
f"Timed out for receiving message within {timeout} milliseconds")
num_res -= 1
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None:
......@@ -864,6 +864,9 @@ def recv_responses(msgseq2pos, timeout=0):
while num_res != 0:
# recv response
msg = recv_rpc_message(timeout)
if msg is None:
raise DGLError(
f"Timed out for receiving message within {timeout} milliseconds")
num_res -= 1
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None:
......@@ -904,7 +907,6 @@ def remote_call_to_machine(target_and_requests, timeout=0):
------
ConnectionError if there is any problem with the connection.
"""
# TODO(chao): handle timeout
msgseq2pos = send_requests_to_machine(target_and_requests)
return recv_responses(msgseq2pos, timeout)
......@@ -955,8 +957,8 @@ def recv_rpc_message(timeout=0):
ConnectionError if there is any problem with the connection.
"""
msg = _CAPI_DGLRPCCreateEmptyRPCMessage()
_CAPI_DGLRPCRecvRPCMessage(timeout, msg)
return msg
status = _CAPI_DGLRPCRecvRPCMessage(timeout, msg)
return msg if status == 0 else None
def client_barrier():
"""Barrier all client processes"""
......
......@@ -108,7 +108,10 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
register_res = rpc.ClientRegisterResponse(client_id)
rpc.send_response(client_id, register_res, group_id)
# receive incomming client requests
req, client_id, group_id = rpc.recv_request()
timeout = 60 * 1000 # in milliseconds
req, client_id, group_id = rpc.recv_request(timeout)
if req is None:
continue
if isinstance(req, rpc.ClientRegisterRequest):
if group_id not in recv_clients:
recv_clients[group_id] = []
......
......@@ -77,8 +77,10 @@ struct RPCReceiver : RPCBase {
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
virtual void Recv(RPCMessage *msg) = 0;
virtual RPCStatus Recv(RPCMessage *msg, int timeout) = 0;
};
} // namespace rpc
......
......@@ -98,27 +98,25 @@ class Receiver : public rpc::RPCReceiver {
* \brief Recv data from Sender
* \param msg pointer of data message
* \param send_id which sender current msg comes from
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return Status code
*
* (1) The Recv() API is blocking, which will not
* return until getting data from message queue.
* (2) The Recv() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
* (1) The Recv() API is thread-safe.
* (2) Memory allocated by communicator but will not own it after the function returns.
*/
virtual STATUS Recv(Message* msg, int* send_id) = 0;
virtual STATUS Recv(Message* msg, int* send_id, int timeout = 0) = 0;
/*!
* \brief Recv data from a specified Sender
* \param msg pointer of data message
* \param send_id sender's ID
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return Status code
*
* (1) The RecvFrom() API is blocking, which will not
* return until getting data from message queue.
* (2) The RecvFrom() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
* (1) The RecvFrom() API is thread-safe.
* (2) Memory allocated by communicator but will not own it after the function returns.
*/
virtual STATUS RecvFrom(Message* msg, int send_id) = 0;
virtual STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) = 0;
protected:
/*!
......
......@@ -275,32 +275,48 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking
return true;
}
void SocketReceiver::Recv(rpc::RPCMessage* msg) {
rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {
Message rpc_meta_msg;
int send_id;
CHECK_EQ(Recv(
&rpc_meta_msg, &send_id), REMOVE_SUCCESS);
auto status = Recv(&rpc_meta_msg, &send_id, timeout);
if (status == QUEUE_EMPTY) {
DLOG(WARNING) << "Timed out when trying to receive rpc meta data after "
<< timeout << " milliseconds.";
return rpc::kRPCTimeOut;
}
CHECK_EQ(status, REMOVE_SUCCESS);
char* count_ptr = rpc_meta_msg.data+rpc_meta_msg.size-sizeof(int32_t);
int32_t nonempty_ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr));
// Recv real ndarray data
std::vector<void*> buffer_list(nonempty_ndarray_count);
for (int i = 0; i < nonempty_ndarray_count; ++i) {
Message ndarray_data_msg;
CHECK_EQ(RecvFrom(
&ndarray_data_msg, send_id), REMOVE_SUCCESS);
status = RecvFrom(&ndarray_data_msg, send_id, timeout);
if (status == QUEUE_EMPTY) {
// As we cannot handle this timeout for now, let's treat it as fatal
// error.
LOG(FATAL) << "Timed out when trying to receive rpc ndarray data after "
<< timeout << " milliseconds.";
return rpc::kRPCTimeOut;
}
CHECK_EQ(status, REMOVE_SUCCESS);
buffer_list[i] = ndarray_data_msg.data;
}
StreamWithBuffer zc_read_strm(rpc_meta_msg.data, rpc_meta_msg.size-sizeof(int32_t), buffer_list);
zc_read_strm.Read(msg);
rpc_meta_msg.deallocator(&rpc_meta_msg);
return rpc::kRPCSuccess;
}
STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
STATUS SocketReceiver::Recv(Message* msg, int* send_id, int timeout) {
// queue_sem_ is a semaphore indicating how many elements in multiple
// message queues.
// When calling queue_sem_.Wait(), this Recv will be suspended until
// queue_sem_ > 0, decrease queue_sem_ by 1, then start to fetch a message.
queue_sem_.Wait();
// queue_sem_ > 0 or specified timeout expires, decrease queue_sem_ by 1,
// then start to fetch a message.
if (!queue_sem_.TimedWait(timeout)) {
return QUEUE_EMPTY;
}
for (;;) {
for (; mq_iter_ != msg_queue_.end(); ++mq_iter_) {
STATUS code = mq_iter_->second->Remove(msg, false);
......@@ -314,11 +330,16 @@ STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
}
mq_iter_ = msg_queue_.begin();
}
LOG(ERROR)
<< "Failed to remove message from queue due to unexpected queue status.";
return QUEUE_CLOSE;
}
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id, int timeout) {
// Get message from specified message queue
queue_sem_.Wait();
if (!queue_sem_.TimedWait(timeout)) {
return QUEUE_EMPTY;
}
STATUS code = msg_queue_[send_id]->Remove(msg);
return code;
}
......
......@@ -172,34 +172,34 @@ class SocketReceiver : public Receiver {
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
void Recv(rpc::RPCMessage* msg) override;
rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override;
/*!
* \brief Recv data from Sender. Actually removing data from msg_queue.
* \param msg pointer of data message
* \param send_id which sender current msg comes from
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return Status code
*
* (1) The Recv() API is blocking, which will not
* return until getting data from message queue.
* (2) The Recv() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
* (1) The Recv() API is thread-safe.
* (2) Memory allocated by communicator but will not own it after the function returns.
*/
STATUS Recv(Message* msg, int* send_id) override;
STATUS Recv(Message* msg, int* send_id, int timeout = 0) override;
/*!
* \brief Recv data from a specified Sender. Actually removing data from msg_queue.
* \param msg pointer of data message
* \param send_id sender's ID
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return Status code
*
* (1) The RecvFrom() API is blocking, which will not
* return until getting data from message queue.
* (2) The RecvFrom() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
* (1) The RecvFrom() API is thread-safe.
* (2) Memory allocated by communicator but will not own it after the function returns.
*/
STATUS RecvFrom(Message* msg, int send_id) override;
STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override;
/*!
* \brief Finalize SocketReceiver
......
......@@ -69,10 +69,22 @@ RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
}
/*!
 * \brief Receive one RPC message through the receiver of the global RPCContext.
 * \param msg Output message, filled in by the receiver on success.
 * \param timeout The timeout value in milliseconds. If zero, wait indefinitely:
 *        the receiver is polled in slices of `retry_timeout` milliseconds and
 *        the call retries until a message arrives.
 * \return kRPCSuccess on success, or kRPCTimeOut if a non-zero timeout expired.
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  static constexpr int32_t retry_timeout = 5 * 1000;  // milliseconds
  const bool wait_forever = (timeout == 0);
  const int32_t real_timeout = wait_forever ? retry_timeout : timeout;
  RPCStatus status;
  do {
    status = RPCContext::getInstance()->receiver->Recv(msg, real_timeout);
    if (status == kRPCTimeOut) {
      // Build the message from the current arguments on every timeout. A
      // function-local static string would be initialized only once and then
      // report a stale timeout value on later calls.
      DLOG(WARNING) << "Recv RPCMessage timeout in " << real_timeout << " ms."
                    << (wait_forever ? " Retrying ..." : "");
    }
  } while (wait_forever && status == kRPCTimeOut);
  return status;
}
void InitGlobalTpContext() {
......@@ -519,9 +531,12 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
}
});
// Recv remote message
for (int i = 0; i < msg_count; ++i) {
int recv_cnt = 0;
while (recv_cnt < msg_count) {
RPCMessage msg;
RecvRPCMessage(&msg, 0);
auto status = RecvRPCMessage(&msg, 0);
CHECK_EQ(status, kRPCSuccess);
++recv_cnt;
int part_id = msg.server_id / group_count;
char* data_char = static_cast<char*>(msg.tensors[0]->data);
dgl_id_t id_size = remote_ids[part_id].size();
......
......@@ -158,12 +158,6 @@ struct RPCContext {
}
};
/*! \brief RPC status flag */
enum RPCStatus {
kRPCSuccess = 0,
kRPCTimeOut,
};
/*!
* \brief Send out one RPC message.
*
......
......@@ -70,6 +70,12 @@ struct RPCMessage : public runtime::Object {
DGL_DEFINE_OBJECT_REF(RPCMessageRef, RPCMessage);
/*! \brief RPC status flag */
enum RPCStatus {
kRPCSuccess = 0,
kRPCTimeOut,
};
} // namespace rpc
} // namespace dgl
......
......@@ -8,9 +8,11 @@
#ifndef DGL_RPC_TENSORPIPE_QUEUE_H_
#define DGL_RPC_TENSORPIPE_QUEUE_H_
#include <dmlc/logging.h>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <chrono>
namespace dgl {
namespace rpc {
......@@ -30,15 +32,23 @@ class Queue {
cv_.notify_all();
}
T pop() {
bool pop(T *msg, int timeout) {
std::unique_lock<std::mutex> lock(mutex_);
while (items_.size() == 0) {
cv_.wait(lock);
if (timeout == 0) {
DLOG(WARNING) << "Will wait infinitely until message is popped...";
cv_.wait(lock, [this] { return items_.size() > 0; });
} else {
if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout),
[this] { return items_.size() > 0; })) {
DLOG(WARNING) << "Times out for popping message after " << timeout
<< " milliseconds.";
return false;
}
T t(std::move(items_.front()));
}
*msg = std::move(items_.front());
items_.pop_front();
cv_.notify_all();
return t;
return true;
}
private:
......
......@@ -193,7 +193,9 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
});
}
void TPReceiver::Recv(RPCMessage *msg) { *msg = std::move(queue_->pop()); }
/*!
 * \brief Pop one RPCMessage from the receive queue.
 * \param msg pointer of RPCmessage that receives the popped message
 * \param timeout The timeout value in milliseconds, forwarded to the queue.
 * \return kRPCSuccess if the queue delivered a message, kRPCTimeOut otherwise.
 */
RPCStatus TPReceiver::Recv(RPCMessage *msg, int timeout) {
  const bool received = queue_->pop(msg, timeout);
  if (!received) {
    return kRPCTimeOut;
  }
  return kRPCSuccess;
}
} // namespace rpc
} // namespace dgl
......@@ -135,8 +135,10 @@ class TPReceiver : public RPCReceiver {
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
void Recv(RPCMessage* msg) override;
RPCStatus Recv(RPCMessage* msg, int timeout) override;
/*!
* \brief Finalize SocketReceiver
......
......@@ -7,6 +7,12 @@
#include <dmlc/logging.h>
#ifndef _WIN32
#include <errno.h>
#include <time.h>
#include <unistd.h>
#endif
namespace dgl {
namespace runtime {
......@@ -23,6 +29,12 @@ void Semaphore::Wait() {
WaitForSingleObject(sem_, INFINITE);
}
bool Semaphore::TimedWait(int) {
// Timed wait is not supported on WIN32.
Wait();
return true;
}
void Semaphore::Post() {
ReleaseSemaphore(sem_, 1, nullptr);
}
......@@ -37,6 +49,42 @@ void Semaphore::Wait() {
sem_wait(&sem_);
}
bool Semaphore::TimedWait(int timeout) {
// zero timeout means wait infinitely
if (timeout == 0) {
DLOG(WARNING) << "Will wait infinitely on semaphore until posted.";
Wait();
return true;
}
timespec ts;
if (clock_gettime(CLOCK_REALTIME, &ts) != 0) {
LOG(ERROR) << "Failed to get current time via clock_gettime. Errno: "
<< errno;
return false;
}
ts.tv_sec += timeout / MILLISECONDS_PER_SECOND;
ts.tv_nsec +=
(timeout % MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND;
if (ts.tv_nsec >= NANOSECONDS_PER_SECOND) {
ts.tv_nsec -= NANOSECONDS_PER_SECOND;
++ts.tv_sec;
}
int ret = 0;
while ((ret = sem_timedwait(&sem_, &ts) != 0) && errno == EINTR) {
continue;
}
if (ret != 0) {
if (errno == ETIMEDOUT) {
DLOG(WARNING) << "sem_timedwait timed out after " << timeout
<< " milliseconds.";
} else {
LOG(ERROR) << "sem_timedwait returns unexpectedly. Errno: " << errno;
}
return false;
}
return true;
}
void Semaphore::Post() {
sem_post(&sem_);
}
......
......@@ -29,16 +29,27 @@ class Semaphore {
* \brief blocking wait, decrease semaphore by 1
*/
void Wait();
/*!
* \brief timed wait, decrease semaphore by 1 or returns if times out
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
*/
bool TimedWait(int timeout);
/*!
* \brief increase semaphore by 1
*/
void Post();
private:
#ifdef _WIN32
HANDLE sem_;
#else
sem_t sem_;
#endif
enum {
MILLISECONDS_PER_SECOND = 1000,
NANOSECONDS_PER_MILLISECOND = 1000 * 1000,
NANOSECONDS_PER_SECOND = 1000 * 1000 * 1000
};
};
} // namespace runtime
......
......@@ -10,7 +10,7 @@
#include <vector>
#include <fstream>
#include <streambuf>
#include <chrono>
#include <stdlib.h>
#include <time.h>
......@@ -62,6 +62,49 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
}
}
// Verifies that RecvFrom()/Recv() deliver queued messages and then report
// QUEUE_EMPTY once the given timeout expires with nothing left to receive.
TEST(SocketCommunicatorTest, SendAndRecvTimeout) {
  std::atomic_bool stop{false};
  // start 1 client, connect to 1 server, send 2 messages
  auto client = std::thread([&stop]() {
    SocketSender sender(kQueueSize, kThreadNum);
    sender.ConnectReceiver(ip_addr[0], 0);
    sender.ConnectReceiverFinalize(kMaxTryTimes);
    for (int i = 0; i < 2; ++i) {
      char *str_data = new char[9];
      memcpy(str_data, "123456789", 9);
      Message msg = {str_data, 9};
      msg.deallocator = DefaultMessageDeleter;
      EXPECT_EQ(sender.Send(msg, 0), ADD_SUCCESS);
    }
    // Keep the sender alive until the server has finished its checks. Sleep
    // between polls instead of spinning so the thread does not burn a core
    // while the server waits out its timeouts.
    while (!stop) {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
    sender.Finalize();
  });
  // start 1 server, accept 1 client, receive 2 messages
  auto server = std::thread([&stop]() {
    SocketReceiver receiver(kQueueSize, kThreadNum);
    receiver.Wait(ip_addr[0], 1);
    Message msg;
    int recv_id;
    // receive 1st message
    EXPECT_EQ(receiver.RecvFrom(&msg, 0, 0), REMOVE_SUCCESS);
    EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
    msg.deallocator(&msg);
    // receive 2nd message
    EXPECT_EQ(receiver.Recv(&msg, &recv_id, 0), REMOVE_SUCCESS);
    EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
    msg.deallocator(&msg);
    // timed out
    EXPECT_EQ(receiver.RecvFrom(&msg, 0, 1000), QUEUE_EMPTY);
    EXPECT_EQ(receiver.Recv(&msg, &recv_id, 1000), QUEUE_EMPTY);
    stop = true;
    receiver.Finalize();
  });
  // join
  client.join();
  server.join();
}
void start_client() {
SocketSender sender(kQueueSize, kThreadNum);
for (int i = 0; i < kNumReceiver; ++i) {
......
......@@ -29,10 +29,12 @@ private:
}
for (int n = 0; n < kNumSender * kNumMessage * num_machines_; ++n) {
dgl::rpc::RPCMessage msg;
receiver.Recv(&msg);
if (receiver.Recv(&msg, 0) != dgl::rpc::kRPCSuccess) {
LOG(FATAL) << "Failed to receive message on Server~" << id;
}
bool eq = msg.data == std::string("123456789");
eq = eq && (msg.tensors.size() == kNumTensor);
for (int j = 0; j < kNumTensor; ++j) {
for (int j = 0; eq && j < kNumTensor; ++j) {
eq = eq && (msg.tensors[j].ToVector<int>().size() == kSizeTensor);
}
if (!eq) {
......
......@@ -83,6 +83,43 @@ class HelloRequest(dgl.distributed.Request):
res = HelloResponse(self.hello_str, self.integer, new_tensor)
return res
TIMEOUT_SERVICE_ID = 123456789
TIMEOUT_META = 'timeout_test'
class TimeoutResponse(dgl.distributed.Response):
    """Response used by the timeout tests; its whole state is the ``meta`` value."""

    def __init__(self, meta):
        self.meta = meta

    def __getstate__(self):
        # The serialized state is just the meta value itself.
        return self.meta

    def __setstate__(self, state):
        self.meta = state
class TimeoutRequest(dgl.distributed.Request):
    """Request that sleeps for ``timeout`` milliseconds when processed.

    Parameters
    ----------
    meta : str
        Marker value; ``process_request`` asserts it equals ``TIMEOUT_META``.
    timeout : int
        How long ``process_request`` sleeps, in milliseconds.
    response : bool
        When False, ``process_request`` returns None instead of a
        ``TimeoutResponse``.
    """

    def __init__(self, meta, timeout, response=True):
        self.meta = meta
        self.timeout = timeout
        self.response = response

    def __getstate__(self):
        return self.meta, self.timeout, self.response

    def __setstate__(self, state):
        self.meta, self.timeout, self.response = state

    def process_request(self, server_state):
        assert self.meta == TIMEOUT_META
        # convert from milliseconds to seconds
        time.sleep(self.timeout/1000)
        return TimeoutResponse(self.meta) if self.response else None
def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1, net_type='tensorpipe'):
print("Sleep 1 seconds to test client re-connect.")
time.sleep(1)
......@@ -90,6 +127,8 @@ def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_serv
None, local_g=None, partition_book=None, keep_alive=keep_alive)
dgl.distributed.register_service(
HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.register_service(
TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse)
print("Start server {}".format(server_id))
dgl.distributed.start_server(server_id=server_id,
ip_config=ip_config,
......@@ -134,6 +173,67 @@ def start_client(ip_config, group_id=0, num_servers=1, net_type='tensorpipe'):
assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
def start_client_timeout(ip_config, group_id=0, num_servers=1, net_type='tensorpipe'):
    """Client-side driver for the RPC timeout test.

    Registers the timeout service, connects to the server, and then checks
    four things in order: ``send_request`` + ``recv_response``,
    ``remote_call``, ``send_request_to_machine`` + ``recv_response``, and
    ``remote_call_to_machine``. Receiving with half of the server-side delay
    must time out (``None`` or ``DGLError``), while an unbounded receive must
    eventually return the response.
    """
    dgl.distributed.register_service(
        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse)
    dgl.distributed.connect_to_server(
        ip_config=ip_config, num_servers=num_servers, group_id=group_id, net_type=net_type)
    timeout = 1 * 1000  # milliseconds
    half_timeout = int(timeout/2)

    def _expect_late_response(send_fn):
        # The request makes the server sleep `timeout` ms, so a receive bounded
        # by half of that returns None; the unbounded receive gets the response.
        delayed_req = TimeoutRequest(TIMEOUT_META, timeout)
        send_fn(0, delayed_req)
        early = dgl.distributed.recv_response(timeout=half_timeout)
        assert early is None
        late = dgl.distributed.recv_response()
        assert late.meta == TIMEOUT_META

    def _expect_call_timeout(call_fn):
        # Requests built with response=False never yield a response object, so
        # the batched call is expected to raise DGLError on timeout.
        silent_req = TimeoutRequest(TIMEOUT_META, timeout, response=False)
        batch = [(0, silent_req) for _ in range(3)]
        raised = False
        try:
            call_fn(batch, timeout=half_timeout)
        except dgl.DGLError:
            raised = True
        assert raised

    # test send and recv
    _expect_late_response(dgl.distributed.send_request)
    # test remote_call
    _expect_call_timeout(dgl.distributed.remote_call)
    # test send_request_to_machine
    _expect_late_response(dgl.distributed.send_request_to_machine)
    # test remote_call_to_machine
    _expect_call_timeout(dgl.distributed.remote_call_to_machine)
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
def test_rpc_timeout(net_type):
    """Run one server process and one timeout-test client process to completion.

    Both processes are created from a 'spawn' multiprocessing context; the
    server is started and joined before the client.
    """
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config.txt"
    generate_ip_config(ip_config, 1, 1)
    spawn_ctx = mp.get_context('spawn')
    workers = [
        spawn_ctx.Process(target=start_server, args=(1, ip_config, 0, False, 1, net_type)),
        spawn_ctx.Process(target=start_client_timeout, args=(ip_config, 0, 1, net_type)),
    ]
    for proc in workers:
        proc.start()
    for proc in workers:
        proc.join()
def test_serialize():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment