Unverified Commit 37467e25 authored by Rhett Ying, committed by GitHub

[Feature][Dist] change TP::Receiver/TP::Sender for multiple connections (#3574)



* [Feature] enable TP::Receiver to wait for any number of senders

* fix random unit test failure

* avoid endless future wait

* fix unit test failure

* fix seg fault when finalize wait in receiver

* [Feature] refactor sender connect logic and remove unnecessary sleeps in unit tests

* fix lint

* release RPCContext resources before process exits

* [Debug] TPReceiver wait start log

* [Debug] add log in get port

* [Debug] add log

* [ReDebug] revert time sleep in unit tests

* [Debug] remove sleep for test_distri,test_mp

* [debug] add more log

* [debug] add listen_booted_ flag

* [debug] restore commented code for queue

* [debug] sleep more in rpc_client

* restore change in tests

* Revert "restore change in tests"

This reverts commit 41a18926d181ec2517069389bfc41de2cc949280.

* Revert "[debug] sleep more in rpc_client"

This reverts commit a908e758eabca0a6ce62eb2e59baea02a840ac67.

* Revert "[debug] restore commented code for queue"

This reverts commit d3f993b3746e6bb6e2cc2f90204dd7e9461c6301.

* Revert "[debug] add listen_booted_ flag"

This reverts commit 244b2167d94942ff2a0acec8823b974975e52580.

* Revert "[debug] add more log"

This reverts commit 4b78447b0a575a824821dc7e25cca2246e6e30e2.

* Revert "[Debug] remove sleep for test_distri,test_mp"

This reverts commit e1df1aadcc8b1c2a0013ed77322ac391a8807612.

* remove debug code

* revert unnecessary change

* revert unnecessary changes

* always reset RPCContext on startup and clear all of its data

* remove time.sleep in dist tests

* fix lint

* reset envs before each dist test

* reset env properly

* add time sleep when start each server

* sleep for a while when boot server

* replace wait_thread with callback

* fix lint

* add dglconnect handshake check
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 95c0ff63
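In short: the sender no longer registers every receiver up front and then connects to all of them at once; it now connects to each receiver directly and learns from the return value whether the peer is up yet. A minimal sketch of the new sender-side pattern (the addresses, ports, and namebook below are placeholders, and the usual create_sender/create_receiver setup is omitted):

```python
import time
from dgl.distributed import rpc

# Old flow (removed by this PR):
#   rpc.add_receiver_addr(ip, port, recv_id)   # register every receiver first
#   rpc.sender_connect()                       # then connect to all of them at once
#
# New flow: connect to each receiver individually and retry until it is up.
receivers = {0: ('127.0.0.1', 30050), 1: ('127.0.0.1', 30051)}  # placeholder namebook
for recv_id, (ip, port) in receivers.items():
    while not rpc.connect_receiver(ip, port, recv_id):
        time.sleep(1)  # receiver not listening yet, retry
```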
......@@ -226,6 +226,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,
'Please define DGL_CONF_PATH to run DistGraph server'
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
rpc.reset()
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
......
......@@ -13,7 +13,7 @@ from .. import backend as F
__all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \
'receiver_wait', 'connect_receiver', 'read_ip_config', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
......@@ -138,7 +138,7 @@ def finalize_receiver():
"""
_CAPI_DGLRPCFinalizeReceiver()
def receiver_wait(ip_addr, port, num_senders):
def receiver_wait(ip_addr, port, num_senders, blocking=True):
"""Wait all of the senders' connections.
This api will be blocked until all the senders connect to the receiver.
......@@ -151,11 +151,13 @@ def receiver_wait(ip_addr, port, num_senders):
receiver's port
num_senders : int
total number of senders
blocking : bool
whether to block until all senders have connected
"""
_CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders))
_CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders), blocking)
def add_receiver_addr(ip_addr, port, recv_id):
"""Add Receiver's IP address to sender's namebook.
def connect_receiver(ip_addr, port, recv_id):
"""Connect to target receiver
Parameters
----------
......@@ -166,12 +168,7 @@ def add_receiver_addr(ip_addr, port, recv_id):
recv_id : int
receiver's ID
"""
_CAPI_DGLRPCAddReceiver(ip_addr, int(port), int(recv_id))
def sender_connect():
"""Connect to all the receivers.
"""
_CAPI_DGLRPCSenderConnect()
return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(recv_id))
def set_rank(rank):
"""Set the rank of this process.
......
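The other half of the change is the new blocking flag on receiver_wait: with the default blocking=True the call spins until num_senders connections have been accepted, while blocking=False returns immediately and lets the background accept callback collect connections as they arrive. A hedged sketch (the address and count are placeholders; only one of the two calls would be used on a given receiver):

```python
from dgl.distributed import rpc

# Block until all four expected senders have connected (the previous behaviour).
rpc.receiver_wait('127.0.0.1', 30050, 4)

# Or: return right away and keep accepting connections in the background,
# which is what both client and server do after this PR.
rpc.receiver_wait('127.0.0.1', 30050, 4, blocking=False)
```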
......@@ -4,6 +4,7 @@ import os
import socket
import atexit
import logging
import time
from . import rpc
from .constants import MAX_QUEUE_SIZE
......@@ -161,17 +162,17 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net
for server_id, addr in server_namebook.items():
server_ip = addr[1]
server_port = addr[2]
rpc.add_receiver_addr(server_ip, server_port, server_id)
rpc.sender_connect()
while not rpc.connect_receiver(server_ip, server_port, server_id):
time.sleep(1)
# Get local usable IP address and port
ip_addr = get_local_usable_addr(server_ip)
client_ip, client_port = ip_addr.split(':')
# wait for the servers to connect back
rpc.receiver_wait(client_ip, client_port, num_servers, blocking=False)
# Register client on server
register_req = rpc.ClientRegisterRequest(ip_addr)
for server_id in range(num_servers):
rpc.send_request(server_id, register_req)
# wait server connect back
rpc.receiver_wait(client_ip, client_port, num_servers)
# recv client ID from server
res = rpc.recv_response()
rpc.set_rank(res.client_id)
......
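Ordering matters in the client hunk above: the client must already be listening (the non-blocking receiver_wait) before it sends the register requests, because those requests are what prompt each server to connect back. A condensed sketch of the sequence, with a placeholder namebook and local address:

```python
import time
from dgl.distributed import rpc

server_namebook = {0: (0, '127.0.0.1', 30050, 1)}   # placeholder: id -> (machine, ip, port, ...)
client_ip, client_port = '127.0.0.1', 40050         # placeholder local address

# 1. Connect to every server, retrying until each one is reachable.
for server_id, addr in server_namebook.items():
    while not rpc.connect_receiver(addr[1], addr[2], server_id):
        time.sleep(1)

# 2. Start listening for the servers' connect-back before asking them to register us.
rpc.receiver_wait(client_ip, client_port, len(server_namebook), blocking=False)

# 3. Register; each server connects back and replies with this client's ID.
register_req = rpc.ClientRegisterRequest('{}:{}'.format(client_ip, client_port))
for server_id in server_namebook:
    rpc.send_request(server_id, register_req)
rpc.set_rank(rpc.recv_response().client_id)
```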
"""Functions used by server."""
import time
from . import rpc
from .constants import MAX_QUEUE_SIZE
......@@ -64,24 +62,23 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# wait for all the senders to connect to the server.
# Once all the senders have connected, the server will not
# accept new sender connections
print("Wait connections ...")
rpc.receiver_wait(ip_addr, port, num_clients)
print("%d clients connected!" % num_clients)
print("Wait connections non-blockingly...")
rpc.receiver_wait(ip_addr, port, num_clients, blocking=False)
rpc.set_num_client(num_clients)
# Recv all the client's IP and assign ID to clients
addr_list = []
client_namebook = {}
for _ in range(num_clients):
# blocked until request is received
req, _ = rpc.recv_request()
assert isinstance(req, rpc.ClientRegisterRequest)
addr_list.append(req.ip_addr)
addr_list.sort()
for client_id, addr in enumerate(addr_list):
client_namebook[client_id] = addr
for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':')
rpc.add_receiver_addr(client_ip, client_port, client_id)
time.sleep(3) # wait client's socket ready. 3 sec is enough.
rpc.sender_connect()
assert rpc.connect_receiver(client_ip, client_port, client_id)
if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id)
......
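The server-side counterpart no longer needs the time.sleep(3) that used to wait for the clients' sockets: by the time a ClientRegisterRequest arrives, that client has already called receiver_wait(blocking=False), so the connect-back can simply be asserted. A condensed sketch (the client count and port are placeholders):

```python
from dgl.distributed import rpc

num_clients = 4                                       # placeholder
rpc.receiver_wait('127.0.0.1', 30050, num_clients, blocking=False)

addr_list = []
for _ in range(num_clients):
    req, _ = rpc.recv_request()                       # blocks until a request arrives
    assert isinstance(req, rpc.ClientRegisterRequest)
    addr_list.append(req.ip_addr)

addr_list.sort()                                      # deterministic client IDs across servers
for client_id, addr in enumerate(addr_list):
    client_ip, client_port = addr.split(':')
    assert rpc.connect_receiver(client_ip, client_port, client_id)
```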
......@@ -143,28 +143,22 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
std::string ip = args[0];
int port = args[1];
int num_sender = args[2];
bool blocking = args[3];
std::string addr;
addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
if (RPCContext::getInstance()->receiver->Wait(addr, num_sender) == false) {
if (RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking) == false) {
LOG(FATAL) << "Wait sender socket failed.";
}
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCAddReceiver")
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::string ip = args[0];
int port = args[1];
int recv_id = args[2];
std::string addr;
addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
RPCContext::getInstance()->sender->AddReceiver(addr, recv_id);
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSenderConnect")
.set_body([](DGLArgs args, DGLRetValue* rv) {
if (RPCContext::getInstance()->sender->Connect() == false) {
LOG(FATAL) << "Sender connection failed.";
}
*rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
......
......@@ -113,12 +113,15 @@ struct RPCContext {
t->rank = -1;
t->machine_id = -1;
t->num_machines = 0;
t->msg_seq = 0;
t->num_servers = 0;
t->num_clients = 0;
t->barrier_count = 0;
t->num_servers_per_machine = 0;
t->sender.reset();
t->receiver.reset();
t->ctx.reset();
t->server_state.reset();
}
};
......
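Because Reset() now also drops the sender, receiver, context, and server state, a process (or a unit test) can re-enter the RPC layer without inheriting pipes from a previous run; the rpc.reset() call added to initialize() above presumably routes into it. A minimal usage sketch:

```python
from dgl.distributed import rpc

# Drop any sender/receiver/state left over from a previous initialization
# before wiring up a new server or client.
rpc.reset()
```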
......@@ -20,54 +20,48 @@ namespace rpc {
using namespace tensorpipe;
void TPSender::AddReceiver(const std::string& addr, int recv_id) {
receiver_addrs_[recv_id] = addr;
}
bool TPSender::Connect() {
for (const auto& kv : receiver_addrs_) {
bool TPSender::ConnectReceiver(const std::string &addr, int recv_id) {
if (pipes_.find(recv_id) != pipes_.end()) {
LOG(WARNING) << "Duplicate recv_id[" << recv_id << "]. Ignoring...";
return true;
}
std::shared_ptr<Pipe> pipe;
for (;;) {
pipe = context->connect(kv.second);
std::promise<bool> done;
pipe = context->connect(addr);
auto done = std::make_shared<std::promise<bool>>();
tensorpipe::Message tpmsg;
tpmsg.metadata = "dglconnect";
pipe->write(tpmsg, [&done](const tensorpipe::Error& error) {
pipe->write(tpmsg, [done](const tensorpipe::Error &error) {
if (error) {
done.set_value(false);
LOG(WARNING) << "Error occurred when write to pipe: " << error.what();
done->set_value(false);
} else {
done.set_value(true);
done->set_value(true);
}
});
if (done.get_future().get()) {
break;
} else {
sleep(5);
LOG(INFO) << "Cannot connect to remove server " << kv.second
<< ". Wait to retry";
}
}
pipes_[kv.first] = pipe;
if (!done->get_future().get()) {
LOG(WARNING) << "Failed to connect to receiver[" << addr << "].";
return false;
}
pipes_[recv_id] = pipe;
return true;
}
void TPSender::Send(const RPCMessage& msg, int recv_id) {
void TPSender::Send(const RPCMessage &msg, int recv_id) {
auto pipe = pipes_[recv_id];
tensorpipe::Message tp_msg;
std::string* zerocopy_blob_ptr = &tp_msg.metadata;
std::string *zerocopy_blob_ptr = &tp_msg.metadata;
StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true);
zc_write_strm.Write(msg);
int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
zerocopy_blob_ptr->append(reinterpret_cast<char*>(&nonempty_ndarray_count),
zerocopy_blob_ptr->append(reinterpret_cast<char *>(&nonempty_ndarray_count),
sizeof(int32_t));
tp_msg.tensors.resize(nonempty_ndarray_count);
// Hold the NDArrays to ensure they stay valid until the write operation completes
auto ndarray_holder = std::make_shared<std::vector<NDArray>>();
ndarray_holder->resize(nonempty_ndarray_count);
auto& buffer_list = zc_write_strm.buffer_list();
for (int i = 0; i < buffer_list.size(); i++) {
auto& ptr = buffer_list[i];
auto &buffer_list = zc_write_strm.buffer_list();
for (size_t i = 0; i < buffer_list.size(); i++) {
auto &ptr = buffer_list[i];
(*ndarray_holder.get())[i] = ptr.tensor;
tensorpipe::CpuBuffer cpu_buffer;
cpu_buffer.ptr = ptr.data;
......@@ -78,7 +72,7 @@ void TPSender::Send(const RPCMessage& msg, int recv_id) {
}
}
pipe->write(tp_msg,
[ndarray_holder, recv_id](const tensorpipe::Error& error) {
[ndarray_holder, recv_id](const tensorpipe::Error &error) {
if (error) {
LOG(FATAL) << "Failed to send message to " << recv_id
<< ". Details: " << error.what();
......@@ -86,37 +80,69 @@ void TPSender::Send(const RPCMessage& msg, int recv_id) {
});
}
void TPSender::Finalize() {}
void TPReceiver::Finalize() {}
void TPSender::Finalize() {
for (auto &&p : pipes_) {
p.second->close();
}
pipes_.clear();
}
void TPReceiver::Finalize() {
listener_->close();
for (auto &&p : pipes_) {
p.second->close();
}
pipes_.clear();
}
bool TPReceiver::Wait(const std::string& addr, int num_sender) {
listener = context->listen({addr});
for (int i = 0; i < num_sender; i++) {
std::promise<std::shared_ptr<Pipe>> pipeProm;
listener->accept([&](const Error& error, std::shared_ptr<Pipe> pipe) {
bool TPReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
if (listener_) {
LOG(WARNING) << "TPReceiver::Wait() has been called already. Ignoring...";
return true;
}
LOG(INFO) << "TPReceiver starts to wait on [" << addr << "].";
listener_ = context->listen({addr});
listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
OnAccepted(error, pipe);
});
while (blocking && (num_sender != num_connected_)) {
}
return true;
}
void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
if (error) {
LOG(WARNING) << error.what();
if (error.isOfType<ListenerClosedError>()) {
// Expected.
} else {
LOG(WARNING) << "Unexpected error when accepting incoming pipe: " << error.what();
}
pipeProm.set_value(std::move(pipe));
return;
}
// Accept the next connection request
listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
OnAccepted(error, pipe);
});
std::shared_ptr<Pipe> pipe = pipeProm.get_future().get();
std::promise<bool> checkConnect;
pipe->readDescriptor(
[pipe, &checkConnect](const Error& error, Descriptor descriptor) {
// read the handshake message: "dglconnect"
pipe->readDescriptor([pipe, this](const Error &error, Descriptor descriptor) {
if (error) {
LOG(WARNING) << "Unexpected error when reading from accepted pipe: " << error.what();
return;
}
Allocation allocation;
checkConnect.set_value(descriptor.metadata == "dglconnect");
pipe->read(allocation, [](const Error& error) {});
});
CHECK(checkConnect.get_future().get()) << "Invalid connect message.";
pipes_[i] = pipe;
pipe->read(allocation, [](const Error &error) {});
CHECK(descriptor.metadata == "dglconnect") << "Invalid connect message.";
pipes_[num_connected_] = pipe;
ReceiveFromPipe(pipe, queue_);
}
return true;
++num_connected_;
});
}
void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
std::shared_ptr<RPCMessageQueue> queue) {
pipe->readDescriptor([pipe, queue = std::move(queue)](const Error& error,
pipe->readDescriptor([pipe, queue = std::move(queue)](const Error &error,
Descriptor descriptor) {
if (error) {
// Error may happen when the pipe is closed
......@@ -128,27 +154,27 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
int tensorsize = descriptor.tensors.size();
if (tensorsize > 0) {
allocation.tensors.resize(tensorsize);
for (int i = 0; i < descriptor.tensors.size(); i++) {
for (size_t i = 0; i < descriptor.tensors.size(); i++) {
tensorpipe::CpuBuffer cpu_buffer;
cpu_buffer.ptr = new char[descriptor.tensors[i].length];
allocation.tensors[i].buffer = cpu_buffer;
}
}
pipe->read(
allocation, [allocation, descriptor = std::move(descriptor),
queue = std::move(queue), pipe](const Error& error) {
pipe->read(allocation, [allocation, descriptor = std::move(descriptor),
queue = std::move(queue),
pipe](const Error &error) {
if (error) {
// A read event is always posted to the epoll loop, so an error is
// raised when the pipe is closed; that error is expected. Other
// errors are not, but we cannot distinguish them from the expected
// one yet, so all errors are skipped here.
return;
}
char* meta_msg_begin = const_cast<char*>(&descriptor.metadata[0]);
std::vector<void*> buffer_list(descriptor.tensors.size());
for (int i = 0; i < descriptor.tensors.size(); i++) {
char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
std::vector<void *> buffer_list(descriptor.tensors.size());
for (size_t i = 0; i < descriptor.tensors.size(); i++) {
buffer_list[i] = allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
}
StreamWithBuffer zc_read_strm(
......@@ -162,7 +188,7 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
});
}
void TPReceiver::Recv(RPCMessage* msg) { *msg = std::move(queue_->pop()); }
void TPReceiver::Recv(RPCMessage *msg) { *msg = std::move(queue_->pop()); }
} // namespace rpc
} // namespace dgl
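The key structural change in tensorpipe.cc is that Wait() no longer loops over a fixed number of accept futures: OnAccepted re-arms itself after every connection, verifies the "dglconnect" handshake, stores the pipe, and bumps an atomic counter that Wait() can optionally spin on. The same shape, sketched with plain Python sockets purely as an illustration of the pattern (this is not DGL or tensorpipe code):

```python
import socket
import threading

class ToyReceiver:
    """Illustrative analogue of TPReceiver::Wait/OnAccepted, not the real implementation."""

    def __init__(self):
        self.pipes = {}
        self.num_connected = 0      # the C++ code uses std::atomic<int32_t>

    def wait(self, ip, port, num_senders, blocking=True):
        self._server = socket.create_server((ip, port))
        # Accept in the background so wait() can return immediately when non-blocking.
        threading.Thread(target=self._accept_loop, daemon=True).start()
        while blocking and self.num_connected < num_senders:
            pass                    # Wait() spins on the counter the same way
        return True

    def _accept_loop(self):
        while True:                 # the callback-free equivalent of re-arming accept()
            conn, _ = self._server.accept()
            if not conn.recv(16).startswith(b"dglconnect"):  # handshake check added by this PR
                conn.close()
                continue
            self.pipes[self.num_connected] = conn
            self.num_connected += 1
```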
......@@ -15,7 +15,7 @@
#include <thread>
#include <unordered_map>
#include <vector>
#include <atomic>
#include "./queue.h"
namespace dgl {
......@@ -42,21 +42,19 @@ class TPSender {
}
/*!
* \brief Add receiver's address and ID to the sender's namebook
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param id receiver's ID
*
* AddReceiver() is not thread-safe and only one thread can invoke this API.
* \brief Sender destructor
*/
void AddReceiver(const std::string& addr, int recv_id);
~TPSender() { Finalize(); }
/*!
* \brief Connect with all the Receivers
* \brief Connect to receiver with address and ID
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID
* \return True for success and False for fail
*
* Connect() is not thread-safe and only one thread can invoke this API.
* ConnectReceiver() is not thread-safe and only one thread can invoke this API.
*/
bool Connect();
bool ConnectReceiver(const std::string& addr, int recv_id);
/*!
* \brief Send RPCMessage to specified Receiver.
......@@ -109,15 +107,21 @@ class TPReceiver {
queue_ = std::make_shared<RPCMessageQueue>();
}
/*!
* \brief Receiver destructor
*/
~TPReceiver() { Finalize(); }
/*!
* \brief Wait for all the Senders to connect
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50051'
* \param num_sender total number of Senders
* \param blocking whether to block until all senders have connected
* \return True for success and False for fail
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
bool Wait(const std::string& addr, int num_sender);
bool Wait(const std::string &addr, int num_sender, bool blocking = true);
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
......@@ -151,6 +155,12 @@ class TPReceiver {
static void ReceiveFromPipe(std::shared_ptr<tensorpipe::Pipe> pipe,
std::shared_ptr<RPCMessageQueue> queue);
private:
/*!
* \brief Callback invoked when a new connection is accepted.
*/
void OnAccepted(const tensorpipe::Error&, std::shared_ptr<tensorpipe::Pipe>);
private:
/*!
* \brief number of senders
......@@ -178,6 +188,16 @@ class TPReceiver {
* \brief RPCMessage queue
*/
std::shared_ptr<RPCMessageQueue> queue_;
/*!
* \brief number of accepted connections
*/
std::atomic<int32_t> num_connected_{0};
/*!
* \brief listener
*/
std::shared_ptr<tensorpipe::Listener> listener_{nullptr};
};
} // namespace rpc
......
......@@ -18,6 +18,7 @@ import backend as F
import math
import unittest
import pickle
from utils import reset_envs
if os.name != 'nt':
import fcntl
......@@ -96,7 +97,6 @@ def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
print('end')
def run_client_empty(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
time.sleep(5)
os.environ['DGL_NUM_SERVER'] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
......@@ -140,7 +140,6 @@ def check_server_client_empty(shared_mem, num_servers, num_clients):
print('clients have terminated')
def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
time.sleep(5)
os.environ['DGL_NUM_SERVER'] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
......@@ -149,7 +148,6 @@ def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_ed
check_dist_graph(g, num_clients, num_nodes, num_edges)
def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
time.sleep(5)
os.environ['DGL_NUM_SERVER'] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
......@@ -158,7 +156,6 @@ def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, nu
check_dist_emb(g, num_clients, num_nodes, num_edges)
def run_client_hierarchy(graph_name, part_id, server_count, node_mask, edge_mask, return_dict):
time.sleep(5)
os.environ['DGL_NUM_SERVER'] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
......@@ -440,7 +437,6 @@ def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
def run_client_hetero(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
time.sleep(5)
os.environ['DGL_NUM_SERVER'] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
......@@ -587,6 +583,7 @@ def check_server_client_hetero(shared_mem, num_servers, num_clients):
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
def test_server_client():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
check_server_client_hierarchy(False, 1, 4)
check_server_client_empty(True, 1, 1)
......@@ -600,6 +597,7 @@ def test_server_client():
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Mxnet doesn't support distributed DistEmbedding")
def test_dist_emb_server_client():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
check_dist_emb_server_client(True, 1, 1)
check_dist_emb_server_client(False, 1, 1)
......@@ -608,6 +606,7 @@ def test_dist_emb_server_client():
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
def test_standalone():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'standalone'
g = create_random_graph(10000)
......@@ -626,6 +625,7 @@ def test_standalone():
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Mxnet doesn't support distributed DistEmbedding")
def test_standalone_node_emb():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'standalone'
g = create_random_graph(10000)
......
......@@ -10,7 +10,7 @@ import multiprocessing as mp
import numpy as np
import backend as F
import time
from utils import get_local_usable_addr
from utils import get_local_usable_addr, reset_envs
from pathlib import Path
import pytest
from scipy import sparse as spsp
......@@ -93,7 +93,6 @@ def check_rpc_sampling(tmpdir, num_server):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
print("Done sampling")
for p in pserver_list:
......@@ -129,7 +128,6 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
eids = F.tensor(np.random.randint(g.number_of_edges(), size=100))
u, v = g.find_edges(orig_eid[eids])
du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)
......@@ -179,7 +177,6 @@ def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
eids = F.tensor(np.random.randint(g.number_of_edges('r1'), size=100))
u, v = g.find_edges(orig_eid['r1'][eids], etype='r1')
du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids, etype='r1')
......@@ -194,6 +191,7 @@ def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_find_edges_shuffle(num_server):
reset_envs()
import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -225,7 +223,6 @@ def check_rpc_get_degree_shuffle(tmpdir, num_server):
for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_get_degrees.json', i)
orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
time.sleep(3)
nids = F.tensor(np.random.randint(g.number_of_nodes(), size=100))
in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(0, tmpdir, num_server > 1, nids)
......@@ -246,6 +243,7 @@ def check_rpc_get_degree_shuffle(tmpdir, num_server):
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_get_degree_shuffle(num_server):
reset_envs()
import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -255,6 +253,7 @@ def test_rpc_get_degree_shuffle(num_server):
#@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip('Only support partition with shuffle')
def test_rpc_sampling():
reset_envs()
import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -282,7 +281,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
print("Done sampling")
for p in pserver_list:
......@@ -379,7 +377,6 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1,
nodes = {'n3': [0, 10, 99, 66, 124, 208]})
print("Done sampling")
......@@ -448,7 +445,6 @@ def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
deg = get_degrees(g, orig_nids['n3'], 'n3')
empty_nids = F.nonzero_1d(deg == 0)
block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1,
......@@ -480,7 +476,6 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
fanout = 3
block, gpb = start_hetero_etype_sample_client(0, tmpdir, num_server > 1, fanout,
nodes={'n3': [0, 10, 99, 66, 124, 208]})
......@@ -545,7 +540,6 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
fanout = 3
deg = get_degrees(g, orig_nids['n3'], 'n3')
empty_nids = F.nonzero_1d(deg == 0)
......@@ -564,6 +558,7 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_sampling_shuffle(num_server):
reset_envs()
import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -637,6 +632,7 @@ def check_standalone_etype_sampling_heterograph(tmpdir, reshuffle):
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_standalone_sampling():
reset_envs()
import tempfile
os.environ['DGL_DIST_MODE'] = 'standalone'
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -680,7 +676,6 @@ def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
pserver_list.append(p)
nodes = [0, 10, 99, 66, 1024, 2008]
time.sleep(3)
sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
for p in pserver_list:
p.join()
......@@ -710,6 +705,7 @@ def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_in_subgraph():
reset_envs()
import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -719,6 +715,7 @@ def test_rpc_in_subgraph():
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
def test_standalone_etype_sampling():
reset_envs()
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
os.environ['DGL_DIST_MODE'] = 'standalone'
......
......@@ -9,7 +9,7 @@ import sys
import multiprocessing as mp
import numpy as np
import time
from utils import get_local_usable_addr
from utils import get_local_usable_addr, reset_envs
from pathlib import Path
from dgl.distributed import DistGraphServer, DistGraph, DistDataLoader
import pytest
......@@ -103,6 +103,7 @@ def start_dist_dataloader(rank, tmpdir, num_server, drop_last, orig_nid, orig_ei
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
def test_standalone(tmpdir):
reset_envs()
ip_config = open("mp_ip_config.txt", "w")
for _ in range(1):
ip_config.write('{}\n'.format(get_local_usable_addr()))
......@@ -198,7 +199,6 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
p.start()
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
os.environ['DGL_DIST_MODE'] = 'distributed'
os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
ptrainer_list = []
......@@ -206,7 +206,6 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
p = ctx.Process(target=start_dist_neg_dataloader, args=(
0, tmpdir, num_server, num_workers, orig_nid, g))
p.start()
time.sleep(1)
ptrainer_list.append(p)
for p in pserver_list:
......@@ -221,6 +220,7 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("reshuffle", [True, False])
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
reset_envs()
ip_config = open("mp_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{}\n'.format(get_local_usable_addr()))
......@@ -244,13 +244,11 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
os.environ['DGL_DIST_MODE'] = 'distributed'
os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
ptrainer = ctx.Process(target=start_dist_dataloader, args=(
0, tmpdir, num_server, drop_last, orig_nid, orig_eid))
ptrainer.start()
time.sleep(1)
for p in pserver_list:
p.join()
......@@ -387,7 +385,6 @@ def check_dataloader(g, tmpdir, num_server, num_workers, dataloader_type):
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
os.environ['DGL_DIST_MODE'] = 'distributed'
os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
ptrainer_list = []
......@@ -395,13 +392,11 @@ def check_dataloader(g, tmpdir, num_server, num_workers, dataloader_type):
p = ctx.Process(target=start_node_dataloader, args=(
0, tmpdir, num_server, num_workers, orig_nid, orig_eid, g))
p.start()
time.sleep(1)
ptrainer_list.append(p)
elif dataloader_type == 'edge':
p = ctx.Process(target=start_edge_dataloader, args=(
0, tmpdir, num_server, num_workers, orig_nid, orig_eid, g))
p.start()
time.sleep(1)
ptrainer_list.append(p)
for p in pserver_list:
p.join()
......@@ -430,6 +425,7 @@ def create_random_hetero():
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
reset_envs()
g = CitationGraphDataset("cora")[0]
check_dataloader(g, tmpdir, num_server, num_workers, dataloader_type)
g = create_random_hetero()
......@@ -441,6 +437,7 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
def test_neg_dataloader(tmpdir, num_server, num_workers):
reset_envs()
g = CitationGraphDataset("cora")[0]
check_neg_dataloader(g, tmpdir, num_server, num_workers)
g = create_random_hetero()
......
......@@ -7,6 +7,7 @@ import backend as F
import unittest, pytest
import multiprocessing as mp
from numpy.testing import assert_array_equal
from utils import reset_envs
if os.name != 'nt':
import fcntl
......@@ -108,8 +109,8 @@ class HelloRequest(dgl.distributed.Request):
return res
def start_server(num_clients, ip_config, server_id=0):
print("Sleep 5 seconds to test client re-connect.")
time.sleep(5)
print("Sleep 2 seconds to test client re-connect.")
time.sleep(2)
server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None)
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
print("Start server {}".format(server_id))
......@@ -155,6 +156,7 @@ def start_client(ip_config):
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
def test_serialize():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
SERVICE_ID = 12345
......@@ -173,6 +175,7 @@ def test_serialize():
assert res.x == res1.x
def test_rpc_msg():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload, RPCMessage
SERVICE_ID = 32452
......@@ -190,6 +193,7 @@ def test_rpc_msg():
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_rpc():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
ip_config = open("rpc_ip_config.txt", "w")
ip_addr = get_local_usable_addr()
......@@ -199,13 +203,13 @@ def test_rpc():
pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt"))
pclient = ctx.Process(target=start_client, args=("rpc_ip_config.txt",))
pserver.start()
time.sleep(1)
pclient.start()
pserver.join()
pclient.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
ip_config = open("rpc_ip_config_mul_client.txt", "w")
ip_addr = get_local_usable_addr()
......@@ -227,6 +231,7 @@ def test_multi_client():
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_thread_rpc():
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
ip_config = open("rpc_ip_config_multithread.txt", "w")
num_servers = 2
......
import socket
import os
def get_local_usable_addr():
"""Get local usable IP and port
......@@ -25,3 +26,10 @@ def get_local_usable_addr():
sock.close()
return ip_addr + ' ' + str(port)
def reset_envs():
"""Reset common environment variable which are set in tests. """
for key in ['DGL_ROLE', 'DGL_NUM_SAMPLER', 'DGL_NUM_SERVER', 'DGL_DIST_MODE', 'DGL_NUM_CLIENT']:
if key in os.environ:
os.environ.pop(key)
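Typical use, as in the tests above: clear whatever a previous case set before configuring the next one explicitly.

```python
def test_something_distributed():
    reset_envs()                                   # drop DGL_* variables from earlier tests
    os.environ['DGL_DIST_MODE'] = 'distributed'    # then set up this test's mode
    ...
```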