Unverified Commit 37467e25 authored by Rhett Ying, committed by GitHub

[Feature][Dist] change TP::Receiver/TP::Sender for multiple connections (#3574)



* [Feature] enable TP::Receiver to wait for any number of senders

* fix random unit test failure

* avoid endless future wait

* fix unit test failure

* fix seg fault when finalize wait in receiver

* [Feature] refactor sender connect logic and remove unnecessary sleeps in unit tests

* fix lint

* release RPCContext resources before process exits

* [Debug] TPReceiver wait start log

* [Debug] add log in get port

* [Debug] add log

* [ReDebug] revert time sleep in unit tests

* [Debug] remove sleep for test_distri,test_mp

* [debug] add more log

* [debug] add listen_booted_ flag

* [debug] restore commented code for queue

* [debug] sleep more in rpc_client

* restore change in tests

* Revert "restore change in tests"

This reverts commit 41a18926d181ec2517069389bfc41de2cc949280.

* Revert "[debug] sleep more in rpc_client"

This reverts commit a908e758eabca0a6ce62eb2e59baea02a840ac67.

* Revert "[debug] restore commented code for queue"

This reverts commit d3f993b3746e6bb6e2cc2f90204dd7e9461c6301.

* Revert "[debug] add listen_booted_ flag"

This reverts commit 244b2167d94942ff2a0acec8823b974975e52580.

* Revert "[debug] add more log"

This reverts commit 4b78447b0a575a824821dc7e25cca2246e6e30e2.

* Revert "[Debug] remove sleep for test_distri,test_mp"

This reverts commit e1df1aadcc8b1c2a0013ed77322ac391a8807612.

* remove debug code

* revert unnecessary change

* revert unnecessary changes

* always reset RPCContext when it is started and reset all data

* remove time.sleep in dist tests

* fix lint

* reset envs before each dist test

* reset env properly

* add time sleep when start each server

* sleep for a while when boot server

* replace wait_thread with callback

* fix lint

* add dglconnect handshake check
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 95c0ff63
@@ -226,6 +226,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,
             'Please define DGL_CONF_PATH to run DistGraph server'
         formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
         formats = [f.strip() for f in formats]
+        rpc.reset()
         serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
                                os.environ.get('DGL_IP_CONFIG'),
                                int(os.environ.get('DGL_NUM_SERVER')),
...
@@ -13,7 +13,7 @@ from .. import backend as F
 __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
            'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
-           'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \
+           'receiver_wait', 'connect_receiver', 'read_ip_config', \
            'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
            'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
            'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \

@@ -138,7 +138,7 @@ def finalize_receiver():
     """
     _CAPI_DGLRPCFinalizeReceiver()

-def receiver_wait(ip_addr, port, num_senders):
+def receiver_wait(ip_addr, port, num_senders, blocking=True):
     """Wait all of the senders' connections.

     This api will be blocked until all the senders connect to the receiver.

@@ -151,11 +151,13 @@ def receiver_wait(ip_addr, port, num_senders):
         receiver's port
     num_senders : int
         total number of senders
+    blocking : bool
+        whether to wait blockingly
     """
-    _CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders))
+    _CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders), blocking)

-def add_receiver_addr(ip_addr, port, recv_id):
-    """Add Receiver's IP address to sender's namebook.
+def connect_receiver(ip_addr, port, recv_id):
+    """Connect to target receiver

     Parameters
     ----------

@@ -166,12 +168,7 @@ def add_receiver_addr(ip_addr, port, recv_id):
     recv_id : int
         receiver's ID
     """
-    _CAPI_DGLRPCAddReceiver(ip_addr, int(port), int(recv_id))
-
-def sender_connect():
-    """Connect to all the receivers.
-    """
-    _CAPI_DGLRPCSenderConnect()
+    return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(recv_id))

 def set_rank(rank):
     """Set the rank of this process.
...
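To make the new flow concrete, here is a minimal usage sketch of the reworked module-level API. The host, port, and IDs are placeholders, and creating the underlying sender/receiver (rpc.create_sender / rpc.create_receiver, whose signatures are not shown in this diff) is assumed to have been done already:

import time
from dgl.distributed import rpc

# Receiver side: start listening and return immediately; connections are
# accepted in the background instead of blocking until all senders arrive.
rpc.receiver_wait('127.0.0.1', 30050, num_senders=1, blocking=False)

# Sender side: connect_receiver() now reports success or failure per target,
# so the caller decides how (and whether) to retry.
while not rpc.connect_receiver('127.0.0.1', 30050, recv_id=0):
    time.sleep(1)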
@@ -4,6 +4,7 @@ import os
 import socket
 import atexit
 import logging
+import time

 from . import rpc
 from .constants import MAX_QUEUE_SIZE

@@ -161,17 +162,17 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net
     for server_id, addr in server_namebook.items():
         server_ip = addr[1]
         server_port = addr[2]
-        rpc.add_receiver_addr(server_ip, server_port, server_id)
-    rpc.sender_connect()
+        while not rpc.connect_receiver(server_ip, server_port, server_id):
+            time.sleep(1)
     # Get local usable IP address and port
     ip_addr = get_local_usable_addr(server_ip)
     client_ip, client_port = ip_addr.split(':')
+    # wait server connect back
+    rpc.receiver_wait(client_ip, client_port, num_servers, blocking=False)
     # Register client on server
     register_req = rpc.ClientRegisterRequest(ip_addr)
     for server_id in range(num_servers):
         rpc.send_request(server_id, register_req)
-    # wait server connect back
-    rpc.receiver_wait(client_ip, client_port, num_servers)
     # recv client ID from server
     res = rpc.recv_response()
     rpc.set_rank(res.client_id)
...
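The client above retries rpc.connect_receiver() once per second until it succeeds, which never gives up if a server address is wrong or a server never comes up. If a bounded wait is preferred, a small wrapper along these lines (purely illustrative, not part of this change) adds a deadline:

import time
from dgl.distributed import rpc

def connect_with_deadline(ip, port, recv_id, timeout=120, interval=1):
    """Retry rpc.connect_receiver() until it succeeds or `timeout` seconds elapse."""
    deadline = time.time() + timeout
    while not rpc.connect_receiver(ip, port, recv_id):
        if time.time() > deadline:
            return False
        time.sleep(interval)
    return True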
"""Functions used by server.""" """Functions used by server."""
import time
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE from .constants import MAX_QUEUE_SIZE
...@@ -64,24 +62,23 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -64,24 +62,23 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# wait all the senders connect to server. # wait all the senders connect to server.
# Once all the senders connect to server, server will not # Once all the senders connect to server, server will not
# accept new sender's connection # accept new sender's connection
print("Wait connections ...") print("Wait connections non-blockingly...")
rpc.receiver_wait(ip_addr, port, num_clients) rpc.receiver_wait(ip_addr, port, num_clients, blocking=False)
print("%d clients connected!" % num_clients)
rpc.set_num_client(num_clients) rpc.set_num_client(num_clients)
# Recv all the client's IP and assign ID to clients # Recv all the client's IP and assign ID to clients
addr_list = [] addr_list = []
client_namebook = {} client_namebook = {}
for _ in range(num_clients): for _ in range(num_clients):
# blocked until request is received
req, _ = rpc.recv_request() req, _ = rpc.recv_request()
assert isinstance(req, rpc.ClientRegisterRequest)
addr_list.append(req.ip_addr) addr_list.append(req.ip_addr)
addr_list.sort() addr_list.sort()
for client_id, addr in enumerate(addr_list): for client_id, addr in enumerate(addr_list):
client_namebook[client_id] = addr client_namebook[client_id] = addr
for client_id, addr in client_namebook.items(): for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':') client_ip, client_port = addr.split(':')
rpc.add_receiver_addr(client_ip, client_port, client_id) assert rpc.connect_receiver(client_ip, client_port, client_id)
time.sleep(3) # wait client's socket ready. 3 sec is enough.
rpc.sender_connect()
if rpc.get_rank() == 0: # server_0 send all the IDs if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items(): for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id) register_res = rpc.ClientRegisterResponse(client_id)
......
@@ -143,28 +143,22 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
     std::string ip = args[0];
     int port = args[1];
     int num_sender = args[2];
+    bool blocking = args[3];
     std::string addr;
     addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
-    if (RPCContext::getInstance()->receiver->Wait(addr, num_sender) == false) {
+    if (RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking) == false) {
       LOG(FATAL) << "Wait sender socket failed.";
     }
   });

-DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCAddReceiver")
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
 .set_body([](DGLArgs args, DGLRetValue* rv) {
     std::string ip = args[0];
     int port = args[1];
     int recv_id = args[2];
     std::string addr;
     addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
-    RPCContext::getInstance()->sender->AddReceiver(addr, recv_id);
-  });
-
-DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSenderConnect")
-.set_body([](DGLArgs args, DGLRetValue* rv) {
-    if (RPCContext::getInstance()->sender->Connect() == false) {
-      LOG(FATAL) << "Sender connection failed.";
-    }
+    *rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);
   });

 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
...
@@ -113,12 +113,15 @@ struct RPCContext {
     t->rank = -1;
     t->machine_id = -1;
     t->num_machines = 0;
+    t->msg_seq = 0;
+    t->num_servers = 0;
     t->num_clients = 0;
     t->barrier_count = 0;
     t->num_servers_per_machine = 0;
     t->sender.reset();
     t->receiver.reset();
     t->ctx.reset();
+    t->server_state.reset();
   }
 };
...
@@ -20,54 +20,48 @@ namespace rpc {

 using namespace tensorpipe;

-void TPSender::AddReceiver(const std::string& addr, int recv_id) {
-  receiver_addrs_[recv_id] = addr;
-}
-
-bool TPSender::Connect() {
-  for (const auto& kv : receiver_addrs_) {
-    std::shared_ptr<Pipe> pipe;
-    for (;;) {
-      pipe = context->connect(kv.second);
-      std::promise<bool> done;
-      tensorpipe::Message tpmsg;
-      tpmsg.metadata = "dglconnect";
-      pipe->write(tpmsg, [&done](const tensorpipe::Error& error) {
-        if (error) {
-          done.set_value(false);
-        } else {
-          done.set_value(true);
-        }
-      });
-      if (done.get_future().get()) {
-        break;
-      } else {
-        sleep(5);
-        LOG(INFO) << "Cannot connect to remove server " << kv.second
-                  << ". Wait to retry";
-      }
-    }
-    pipes_[kv.first] = pipe;
-  }
-  return true;
-}
+bool TPSender::ConnectReceiver(const std::string &addr, int recv_id) {
+  if (pipes_.find(recv_id) != pipes_.end()) {
+    LOG(WARNING) << "Duplicate recv_id[" << recv_id << "]. Ignoring...";
+    return true;
+  }
+  std::shared_ptr<Pipe> pipe;
+  pipe = context->connect(addr);
+  auto done = std::make_shared<std::promise<bool>>();
+  tensorpipe::Message tpmsg;
+  tpmsg.metadata = "dglconnect";
+  pipe->write(tpmsg, [done](const tensorpipe::Error &error) {
+    if (error) {
+      LOG(WARNING) << "Error occurred when write to pipe: " << error.what();
+      done->set_value(false);
+    } else {
+      done->set_value(true);
+    }
+  });
+  if (!done->get_future().get()) {
+    LOG(WARNING) << "Failed to connect to receiver[" << addr << "].";
+    return false;
+  }
+  pipes_[recv_id] = pipe;
+  return true;
+}
-void TPSender::Send(const RPCMessage& msg, int recv_id) {
+void TPSender::Send(const RPCMessage &msg, int recv_id) {
   auto pipe = pipes_[recv_id];
   tensorpipe::Message tp_msg;
-  std::string* zerocopy_blob_ptr = &tp_msg.metadata;
+  std::string *zerocopy_blob_ptr = &tp_msg.metadata;
   StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true);
   zc_write_strm.Write(msg);
   int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
-  zerocopy_blob_ptr->append(reinterpret_cast<char*>(&nonempty_ndarray_count),
+  zerocopy_blob_ptr->append(reinterpret_cast<char *>(&nonempty_ndarray_count),
                             sizeof(int32_t));
   tp_msg.tensors.resize(nonempty_ndarray_count);
   // Hold the NDArray that ensure it's valid until write operation completes
   auto ndarray_holder = std::make_shared<std::vector<NDArray>>();
   ndarray_holder->resize(nonempty_ndarray_count);
-  auto& buffer_list = zc_write_strm.buffer_list();
-  for (int i = 0; i < buffer_list.size(); i++) {
-    auto& ptr = buffer_list[i];
+  auto &buffer_list = zc_write_strm.buffer_list();
+  for (size_t i = 0; i < buffer_list.size(); i++) {
+    auto &ptr = buffer_list[i];
     (*ndarray_holder.get())[i] = ptr.tensor;
     tensorpipe::CpuBuffer cpu_buffer;
     cpu_buffer.ptr = ptr.data;

@@ -78,7 +72,7 @@ void TPSender::Send(const RPCMessage& msg, int recv_id) {
     }
   }
   pipe->write(tp_msg,
-              [ndarray_holder, recv_id](const tensorpipe::Error& error) {
+              [ndarray_holder, recv_id](const tensorpipe::Error &error) {
                 if (error) {
                   LOG(FATAL) << "Failed to send message to " << recv_id
                              << ". Details: " << error.what();

@@ -86,37 +80,69 @@ void TPSender::Send(const RPCMessage& msg, int recv_id) {
   });
 }
-void TPSender::Finalize() {}
-void TPReceiver::Finalize() {}
+void TPSender::Finalize() {
+  for (auto &&p : pipes_) {
+    p.second->close();
+  }
+  pipes_.clear();
+}
+
+void TPReceiver::Finalize() {
+  listener_->close();
+  for (auto &&p : pipes_) {
+    p.second->close();
+  }
+  pipes_.clear();
+}

-bool TPReceiver::Wait(const std::string& addr, int num_sender) {
-  listener = context->listen({addr});
-  for (int i = 0; i < num_sender; i++) {
-    std::promise<std::shared_ptr<Pipe>> pipeProm;
-    listener->accept([&](const Error& error, std::shared_ptr<Pipe> pipe) {
-      if (error) {
-        LOG(WARNING) << error.what();
-      }
-      pipeProm.set_value(std::move(pipe));
-    });
-    std::shared_ptr<Pipe> pipe = pipeProm.get_future().get();
-    std::promise<bool> checkConnect;
-    pipe->readDescriptor(
-        [pipe, &checkConnect](const Error& error, Descriptor descriptor) {
-          Allocation allocation;
-          checkConnect.set_value(descriptor.metadata == "dglconnect");
-          pipe->read(allocation, [](const Error& error) {});
-        });
-    CHECK(checkConnect.get_future().get()) << "Invalid connect message.";
-    pipes_[i] = pipe;
-    ReceiveFromPipe(pipe, queue_);
-  }
-  return true;
-}
+bool TPReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
+  if (listener_) {
+    LOG(WARNING) << "TPReceiver::Wait() has been called already. Ignoring...";
+    return true;
+  }
+  LOG(INFO) << "TPReceiver starts to wait on [" << addr << "].";
+  listener_ = context->listen({addr});
+  listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
+    OnAccepted(error, pipe);
+  });
+  while (blocking && (num_sender != num_connected_)) {
+  }
+  return true;
+}
+void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
+  if (error) {
+    if (error.isOfType<ListenerClosedError>()) {
+      // Expected.
+    } else {
+      LOG(WARNING) << "Unexpected error when accepting incoming pipe: " << error.what();
+    }
+    return;
+  }
+  // Accept the next connection request
+  listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
+    OnAccepted(error, pipe);
+  });
+  // read the handshake message: "dglconnect"
+  pipe->readDescriptor([pipe, this](const Error &error, Descriptor descriptor) {
+    if (error) {
+      LOG(WARNING) << "Unexpected error when reading from accepted pipe: " << error.what();
+      return;
+    }
+    Allocation allocation;
+    pipe->read(allocation, [](const Error &error) {});
+    CHECK(descriptor.metadata == "dglconnect") << "Invalid connect message.";
+    pipes_[num_connected_] = pipe;
+    ReceiveFromPipe(pipe, queue_);
+    ++num_connected_;
+  });
+}
 void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
                                  std::shared_ptr<RPCMessageQueue> queue) {
-  pipe->readDescriptor([pipe, queue = std::move(queue)](const Error& error,
+  pipe->readDescriptor([pipe, queue = std::move(queue)](const Error &error,
                                                         Descriptor descriptor) {
     if (error) {
       // Error may happen when the pipe is closed

@@ -128,41 +154,41 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
     int tensorsize = descriptor.tensors.size();
     if (tensorsize > 0) {
       allocation.tensors.resize(tensorsize);
-      for (int i = 0; i < descriptor.tensors.size(); i++) {
+      for (size_t i = 0; i < descriptor.tensors.size(); i++) {
         tensorpipe::CpuBuffer cpu_buffer;
         cpu_buffer.ptr = new char[descriptor.tensors[i].length];
         allocation.tensors[i].buffer = cpu_buffer;
       }
     }
-    pipe->read(
-        allocation, [allocation, descriptor = std::move(descriptor),
-                     queue = std::move(queue), pipe](const Error& error) {
+    pipe->read(allocation, [allocation, descriptor = std::move(descriptor),
+                            queue = std::move(queue),
+                            pipe](const Error &error) {
       if (error) {
         // Because we always have a read event posted to the epoll,
        // Therefore when pipe is closed, error will be raised.
        // But this error is expected.
-        // Other error is not expected. But we cannot identify the error with each
-        // Other for now. Thus here we skip handling for all errors
+        // Other error is not expected. But we cannot identify the error with
+        // each Other for now. Thus here we skip handling for all errors
        return;
      }
-      char* meta_msg_begin = const_cast<char*>(&descriptor.metadata[0]);
-      std::vector<void*> buffer_list(descriptor.tensors.size());
-      for (int i = 0; i < descriptor.tensors.size(); i++) {
+      char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
+      std::vector<void *> buffer_list(descriptor.tensors.size());
+      for (size_t i = 0; i < descriptor.tensors.size(); i++) {
        buffer_list[i] = allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
      }
      StreamWithBuffer zc_read_strm(
          meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
          buffer_list);
      RPCMessage msg;
      zc_read_strm.Read(&msg);
      queue->push(msg);
      TPReceiver::ReceiveFromPipe(pipe, queue);
    });
   });
 }

-void TPReceiver::Recv(RPCMessage* msg) { *msg = std::move(queue_->pop()); }
+void TPReceiver::Recv(RPCMessage *msg) { *msg = std::move(queue_->pop()); }

 }  // namespace rpc
 }  // namespace dgl
@@ -15,7 +15,7 @@
 #include <thread>
 #include <unordered_map>
 #include <vector>
+#include <atomic>

 #include "./queue.h"

 namespace dgl {

@@ -42,21 +42,19 @@ class TPSender {
   }

   /*!
-   * \brief Add receiver's address and ID to the sender's namebook
-   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
-   * \param id receiver's ID
-   *
-   * AddReceiver() is not thread-safe and only one thread can invoke this API.
+   * \brief Sender destructor
    */
-  void AddReceiver(const std::string& addr, int recv_id);
+  ~TPSender() { Finalize(); }

   /*!
-   * \brief Connect with all the Receivers
+   * \brief Connect to receiver with address and ID
+   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
+   * \param recv_id receiver's ID
    * \return True for success and False for fail
    *
-   * Connect() is not thread-safe and only one thread can invoke this API.
+   * ConnectReceiver() is not thread-safe and only one thread can invoke this API.
    */
-  bool Connect();
+  bool ConnectReceiver(const std::string& addr, int recv_id);

   /*!
    * \brief Send RPCMessage to specified Receiver.

@@ -109,15 +107,21 @@ class TPReceiver {
     queue_ = std::make_shared<RPCMessageQueue>();
   }

+  /*!
+   * \brief Receiver destructor
+   */
+  ~TPReceiver() { Finalize(); }
+
   /*!
    * \brief Wait for all the Senders to connect
    * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051'
    * \param num_sender total number of Senders
+   * \param blocking whether to wait blockingly
    * \return True for success and False for fail
    *
    * Wait() is not thread-safe and only one thread can invoke this API.
    */
-  bool Wait(const std::string& addr, int num_sender);
+  bool Wait(const std::string &addr, int num_sender, bool blocking = true);

   /*!
    * \brief Recv RPCMessage from Sender. Actually removing data from queue.

@@ -151,6 +155,12 @@ class TPReceiver {
   static void ReceiveFromPipe(std::shared_ptr<tensorpipe::Pipe> pipe,
                               std::shared_ptr<RPCMessageQueue> queue);

+ private:
+  /*!
+   * \brief Callback for new connection is accepted.
+   */
+  void OnAccepted(const tensorpipe::Error&, std::shared_ptr<tensorpipe::Pipe>);
+
  private:
   /*!
    * \brief number of sender

@@ -178,6 +188,16 @@ class TPReceiver {
    * \brief RPCMessage queue
    */
   std::shared_ptr<RPCMessageQueue> queue_;
+
+  /*!
+   * \brief number of accepted connections
+   */
+  std::atomic<int32_t> num_connected_{0};
+
+  /*!
+   * \brief listner
+   */
+  std::shared_ptr<tensorpipe::Listener> listener_{nullptr};
 };

 }  // namespace rpc
...
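The "dglconnect" handshake and the re-arming accept callback are tensorpipe-specific here, but the pattern is easy to see in isolation: the sender writes a fixed greeting as its first message, and the receiver only registers a connection after the greeting checks out, immediately posting another accept so any number of senders can join. A rough Python analogue using plain sockets (purely illustrative; it assumes Python 3.8+ for socket.create_server, while DGL itself does this in C++ over tensorpipe) looks like this:

import socket
import threading

HANDSHAKE = b"dglconnect"

def serve(host, port, on_connected):
    """Accept connections in a background thread; keep only those whose first
    bytes match the expected handshake, then hand them to `on_connected`."""
    listener = socket.create_server((host, port))

    def accept_loop():
        while True:
            conn, _ = listener.accept()
            # A real implementation would loop until all handshake bytes arrive.
            greeting = conn.recv(len(HANDSHAKE))
            if greeting == HANDSHAKE:
                on_connected(conn)   # analogous to ReceiveFromPipe(): start consuming messages
            else:
                conn.close()         # invalid connect message

    threading.Thread(target=accept_loop, daemon=True).start()
    return listener                  # caller may close() it later, like TPReceiver::Finalize()

def connect(host, port):
    """Dial the receiver and send the handshake; returns the socket or None."""
    try:
        conn = socket.create_connection((host, port), timeout=5)
        conn.sendall(HANDSHAKE)
        return conn
    except OSError:
        return None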
@@ -18,6 +18,7 @@ import backend as F
 import math
 import unittest
 import pickle
+from utils import reset_envs

 if os.name != 'nt':
     import fcntl

@@ -96,7 +97,6 @@ def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
     print('end')

 def run_client_empty(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
-    time.sleep(5)
     os.environ['DGL_NUM_SERVER'] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),

@@ -140,7 +140,6 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
     print('clients have terminated')

 def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
-    time.sleep(5)
     os.environ['DGL_NUM_SERVER'] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),

@@ -149,7 +148,6 @@ def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
     check_dist_graph(g, num_clients, num_nodes, num_edges)

 def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
-    time.sleep(5)
     os.environ['DGL_NUM_SERVER'] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),

@@ -158,7 +156,6 @@ def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
     check_dist_emb(g, num_clients, num_nodes, num_edges)

 def run_client_hierarchy(graph_name, part_id, server_count, node_mask, edge_mask, return_dict):
-    time.sleep(5)
     os.environ['DGL_NUM_SERVER'] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),

@@ -440,7 +437,6 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):

 def run_client_hetero(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
-    time.sleep(5)
     os.environ['DGL_NUM_SERVER'] = str(server_count)
     dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),

@@ -587,6 +583,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
 @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
 @unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
 def test_server_client():
+    reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
     check_server_client_hierarchy(False, 1, 4)
     check_server_client_empty(True, 1, 1)

@@ -600,6 +597,7 @@ def test_server_client():
 @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
 @unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Mxnet doesn't support distributed DistEmbedding")
 def test_dist_emb_server_client():
+    reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
     check_dist_emb_server_client(True, 1, 1)
     check_dist_emb_server_client(False, 1, 1)

@@ -608,6 +606,7 @@ def test_dist_emb_server_client():
 @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
 @unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
 def test_standalone():
+    reset_envs()
     os.environ['DGL_DIST_MODE'] = 'standalone'
     g = create_random_graph(10000)

@@ -626,6 +625,7 @@ def test_standalone():
 @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
 @unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Mxnet doesn't support distributed DistEmbedding")
 def test_standalone_node_emb():
+    reset_envs()
     os.environ['DGL_DIST_MODE'] = 'standalone'
     g = create_random_graph(10000)
...
@@ -10,7 +10,7 @@ import multiprocessing as mp
 import numpy as np
 import backend as F
 import time
-from utils import get_local_usable_addr
+from utils import get_local_usable_addr, reset_envs
 from pathlib import Path
 import pytest
 from scipy import sparse as spsp

@@ -93,7 +93,6 @@ def check_rpc_sampling(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
     print("Done sampling")
     for p in pserver_list:

@@ -129,7 +128,6 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     eids = F.tensor(np.random.randint(g.number_of_edges(), size=100))
     u, v = g.find_edges(orig_eid[eids])
     du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

@@ -179,7 +177,6 @@ def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     eids = F.tensor(np.random.randint(g.number_of_edges('r1'), size=100))
     u, v = g.find_edges(orig_eid['r1'][eids], etype='r1')
     du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids, etype='r1')

@@ -194,6 +191,7 @@ def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
 @unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
 @pytest.mark.parametrize("num_server", [1, 2])
 def test_rpc_find_edges_shuffle(num_server):
+    reset_envs()
     import tempfile
     os.environ['DGL_DIST_MODE'] = 'distributed'
     with tempfile.TemporaryDirectory() as tmpdirname:

@@ -225,7 +223,6 @@ def check_rpc_get_degree_shuffle(tmpdir, num_server):
     for i in range(num_server):
         part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_get_degrees.json', i)
         orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
-    time.sleep(3)

     nids = F.tensor(np.random.randint(g.number_of_nodes(), size=100))
     in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(0, tmpdir, num_server > 1, nids)

@@ -246,6 +243,7 @@ def check_rpc_get_degree_shuffle(tmpdir, num_server):
 @unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
 @pytest.mark.parametrize("num_server", [1, 2])
 def test_rpc_get_degree_shuffle(num_server):
+    reset_envs()
     import tempfile
     os.environ['DGL_DIST_MODE'] = 'distributed'
     with tempfile.TemporaryDirectory() as tmpdirname:

@@ -255,6 +253,7 @@ def test_rpc_get_degree_shuffle(num_server):
 #@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 @unittest.skip('Only support partition with shuffle')
 def test_rpc_sampling():
+    reset_envs()
     import tempfile
     os.environ['DGL_DIST_MODE'] = 'distributed'
     with tempfile.TemporaryDirectory() as tmpdirname:

@@ -282,7 +281,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
     print("Done sampling")
     for p in pserver_list:

@@ -379,7 +377,6 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1,
                                             nodes = {'n3': [0, 10, 99, 66, 124, 208]})
     print("Done sampling")

@@ -448,7 +445,6 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     deg = get_degrees(g, orig_nids['n3'], 'n3')
     empty_nids = F.nonzero_1d(deg == 0)
     block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1,

@@ -480,7 +476,6 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     fanout = 3
     block, gpb = start_hetero_etype_sample_client(0, tmpdir, num_server > 1, fanout,
                                                   nodes={'n3': [0, 10, 99, 66, 124, 208]})

@@ -545,7 +540,6 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     fanout = 3
     deg = get_degrees(g, orig_nids['n3'], 'n3')
     empty_nids = F.nonzero_1d(deg == 0)

@@ -564,6 +558,7 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
 @unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
 @pytest.mark.parametrize("num_server", [1, 2])
 def test_rpc_sampling_shuffle(num_server):
+    reset_envs()
     import tempfile
     os.environ['DGL_DIST_MODE'] = 'distributed'
     with tempfile.TemporaryDirectory() as tmpdirname:

@@ -637,6 +632,7 @@ def check_standalone_etype_sampling_heterograph(tmpdir, reshuffle):
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 def test_standalone_sampling():
+    reset_envs()
     import tempfile
     os.environ['DGL_DIST_MODE'] = 'standalone'
     with tempfile.TemporaryDirectory() as tmpdirname:

@@ -680,7 +676,6 @@ def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
         pserver_list.append(p)

     nodes = [0, 10, 99, 66, 1024, 2008]
-    time.sleep(3)
     sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
     for p in pserver_list:
         p.join()

@@ -710,6 +705,7 @@ def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 def test_rpc_in_subgraph():
+    reset_envs()
     import tempfile
     os.environ['DGL_DIST_MODE'] = 'distributed'
     with tempfile.TemporaryDirectory() as tmpdirname:

@@ -719,6 +715,7 @@ def test_rpc_in_subgraph():
 @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
 @unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
 def test_standalone_etype_sampling():
+    reset_envs()
     import tempfile
     with tempfile.TemporaryDirectory() as tmpdirname:
         os.environ['DGL_DIST_MODE'] = 'standalone'
...
@@ -9,7 +9,7 @@ import sys
 import multiprocessing as mp
 import numpy as np
 import time
-from utils import get_local_usable_addr
+from utils import get_local_usable_addr, reset_envs
 from pathlib import Path
 from dgl.distributed import DistGraphServer, DistGraph, DistDataLoader
 import pytest

@@ -103,6 +103,7 @@ def start_dist_dataloader(rank, tmpdir, num_server, drop_last, orig_nid, orig_ei
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
 def test_standalone(tmpdir):
+    reset_envs()
     ip_config = open("mp_ip_config.txt", "w")
     for _ in range(1):
         ip_config.write('{}\n'.format(get_local_usable_addr()))

@@ -198,7 +199,6 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
         p.start()
         time.sleep(1)
         pserver_list.append(p)
-    time.sleep(3)

     os.environ['DGL_DIST_MODE'] = 'distributed'
     os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
     ptrainer_list = []

@@ -206,7 +206,6 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
     p = ctx.Process(target=start_dist_neg_dataloader, args=(
         0, tmpdir, num_server, num_workers, orig_nid, g))
     p.start()
-    time.sleep(1)
     ptrainer_list.append(p)

     for p in pserver_list:

@@ -221,6 +220,7 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
 @pytest.mark.parametrize("drop_last", [True, False])
 @pytest.mark.parametrize("reshuffle", [True, False])
 def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
+    reset_envs()
     ip_config = open("mp_ip_config.txt", "w")
     for _ in range(num_server):
         ip_config.write('{}\n'.format(get_local_usable_addr()))

@@ -244,13 +244,11 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     os.environ['DGL_DIST_MODE'] = 'distributed'
     os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
     ptrainer = ctx.Process(target=start_dist_dataloader, args=(
         0, tmpdir, num_server, drop_last, orig_nid, orig_eid))
     ptrainer.start()
-    time.sleep(1)

     for p in pserver_list:
         p.join()

@@ -387,7 +385,6 @@ def check_dataloader(g, tmpdir, num_server, num_workers, dataloader_type):
         time.sleep(1)
         pserver_list.append(p)

-    time.sleep(3)
     os.environ['DGL_DIST_MODE'] = 'distributed'
     os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
     ptrainer_list = []

@@ -395,13 +392,11 @@ def check_dataloader(g, tmpdir, num_server, num_workers, dataloader_type):
         p = ctx.Process(target=start_node_dataloader, args=(
             0, tmpdir, num_server, num_workers, orig_nid, orig_eid, g))
         p.start()
-        time.sleep(1)
         ptrainer_list.append(p)
     elif dataloader_type == 'edge':
         p = ctx.Process(target=start_edge_dataloader, args=(
             0, tmpdir, num_server, num_workers, orig_nid, orig_eid, g))
         p.start()
-        time.sleep(1)
         ptrainer_list.append(p)
     for p in pserver_list:
         p.join()

@@ -430,6 +425,7 @@ def create_random_hetero():
 @pytest.mark.parametrize("num_workers", [0, 4])
 @pytest.mark.parametrize("dataloader_type", ["node", "edge"])
 def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
+    reset_envs()
     g = CitationGraphDataset("cora")[0]
     check_dataloader(g, tmpdir, num_server, num_workers, dataloader_type)
     g = create_random_hetero()

@@ -441,6 +437,7 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
 @pytest.mark.parametrize("num_server", [3])
 @pytest.mark.parametrize("num_workers", [0, 4])
 def test_neg_dataloader(tmpdir, num_server, num_workers):
+    reset_envs()
     g = CitationGraphDataset("cora")[0]
     check_neg_dataloader(g, tmpdir, num_server, num_workers)
     g = create_random_hetero()
...
@@ -7,6 +7,7 @@ import backend as F
 import unittest, pytest
 import multiprocessing as mp
 from numpy.testing import assert_array_equal
+from utils import reset_envs

 if os.name != 'nt':
     import fcntl

@@ -108,8 +109,8 @@ class HelloRequest(dgl.distributed.Request):
         return res

 def start_server(num_clients, ip_config, server_id=0):
-    print("Sleep 5 seconds to test client re-connect.")
-    time.sleep(5)
+    print("Sleep 2 seconds to test client re-connect.")
+    time.sleep(2)
     server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None)
     dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
     print("Start server {}".format(server_id))

@@ -155,6 +156,7 @@ def start_client(ip_config):
     assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))

 def test_serialize():
+    reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
     from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
     SERVICE_ID = 12345

@@ -173,6 +175,7 @@ def test_serialize():
     assert res.x == res1.x

 def test_rpc_msg():
+    reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
     from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload, RPCMessage
     SERVICE_ID = 32452

@@ -190,6 +193,7 @@ def test_rpc_msg():
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_rpc():
+    reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
     ip_config = open("rpc_ip_config.txt", "w")
     ip_addr = get_local_usable_addr()

@@ -199,13 +203,13 @@ def test_rpc():
     pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt"))
     pclient = ctx.Process(target=start_client, args=("rpc_ip_config.txt",))
     pserver.start()
-    time.sleep(1)
     pclient.start()
     pserver.join()
     pclient.join()

 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_multi_client():
+    reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
     ip_config = open("rpc_ip_config_mul_client.txt", "w")
     ip_addr = get_local_usable_addr()

@@ -227,6 +231,7 @@ def test_multi_client():
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_multi_thread_rpc():
+    reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
     ip_config = open("rpc_ip_config_multithread.txt", "w")
     num_servers = 2
...
 import socket
+import os

 def get_local_usable_addr():
     """Get local usable IP and port

@@ -24,4 +25,11 @@ def get_local_usable_addr():
     port = sock.getsockname()[1]
     sock.close()

     return ip_addr + ' ' + str(port)
\ No newline at end of file
+
+def reset_envs():
+    """Reset common environment variable which are set in tests. """
+    for key in ['DGL_ROLE', 'DGL_NUM_SAMPLER', 'DGL_NUM_SERVER', 'DGL_DIST_MODE', 'DGL_NUM_CLIENT']:
+        if key in os.environ:
+            os.environ.pop(key)
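reset_envs() exists because dgl.distributed.initialize() and the test servers read their configuration from environment variables, so a variable leaked by one test can silently change the behaviour of the next. A typical call pattern in the tests above (sketched here, with the body of the test omitted) is:

import os
from utils import reset_envs

def test_something_distributed():
    reset_envs()                                  # drop DGL_* settings left over from earlier tests
    os.environ['DGL_DIST_MODE'] = 'distributed'   # then set only what this test needs
    # ... start servers/clients and run assertions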