Unverified Commit 72ec1c95 authored by Rhett Ying, committed by GitHub

[DistDGL] remove tensorpipe cpp code (#5850)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent df97f2e8
@@ -147,7 +147,7 @@ def create_sender(max_queue_size):
         Maximal size (bytes) of network queue buffer.
     """
     max_thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0"))
-    _CAPI_DGLRPCCreateSender(int(max_queue_size), "socket", max_thread_count)
+    _CAPI_DGLRPCCreateSender(int(max_queue_size), max_thread_count)
 def create_receiver(max_queue_size):
@@ -159,7 +159,7 @@ def create_receiver(max_queue_size):
         Maximal size (bytes) of network queue buffer.
     """
     max_thread_count = int(os.getenv("DGL_SOCKET_MAX_THREAD_COUNT", "0"))
-    _CAPI_DGLRPCCreateReceiver(int(max_queue_size), "socket", max_thread_count)
+    _CAPI_DGLRPCCreateReceiver(int(max_queue_size), max_thread_count)
 def finalize_sender():
...
@@ -12,7 +12,6 @@
 #include <dgl/runtime/container.h>
 #include <dgl/runtime/parallel_for.h>
 #include <dgl/zerocopy_serializer.h>
-#include <tensorpipe/tensorpipe.h>
 #include <unistd.h>
 #include <csignal>
@@ -27,42 +26,11 @@ using namespace dgl::runtime;
 namespace dgl {
 namespace rpc {
-using namespace tensorpipe;
 // Borrow from PyTorch
 const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
 const char kDefaultUvAddress[] = "127.0.0.1";
-const std::string& guessAddress() {
-  static const std::string uvAddress = []() {
-    tensorpipe::Error error;
-    std::string result;
-    char* ifnameEnv = std::getenv(kSocketIfnameEnvVar);
-    if (ifnameEnv != nullptr) {
-      std::tie(error, result) =
-          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
-      if (error) {
-        LOG(WARNING) << "Failed to look up the IP address for interface "
-                     << ifnameEnv << " (" << error.what() << "), defaulting to "
-                     << kDefaultUvAddress;
-        return std::string(kDefaultUvAddress);
-      }
-    } else {
-      std::tie(error, result) =
-          tensorpipe::transport::uv::lookupAddrForHostname();
-      if (error) {
-        LOG(WARNING) << "Failed to look up the IP address for the hostname ("
-                     << error.what() << "), defaulting to "
-                     << kDefaultUvAddress;
-        return std::string(kDefaultUvAddress);
-      }
-    }
-    return result;
-  }();
-  return uvAddress;
-}
 RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
   RPCContext::getInstance()->sender->Send(msg, target_id);
   return kRPCSuccess;
@@ -87,38 +55,6 @@ RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
   return status;
 }
-void InitGlobalTpContext() {
-  if (!RPCContext::getInstance()->ctx) {
-    RPCContext::getInstance()->ctx = std::make_shared<tensorpipe::Context>();
-    auto context = RPCContext::getInstance()->ctx;
-    auto transportContext = tensorpipe::transport::uv::create();
-    auto shmtransport = tensorpipe::transport::shm::create();
-    context->registerTransport(0 /* priority */, "tcp", transportContext);
-    // Register basic uv channel
-    auto basicChannel = tensorpipe::channel::basic::create();
-    context->registerChannel(0 /* low priority */, "basic", basicChannel);
-    char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
-    if (numUvThreads_str) {
-      int numUvThreads = std::atoi(numUvThreads_str);
-      CHECK(numUvThreads > 0)
-          << "DGL_SOCKET_NTHREADS should be positive integer if set";
-      // Register multiplex uv channel
-      std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
-      std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
-      for (int i = 0; i < numUvThreads; i++) {
-        auto context = tensorpipe::transport::uv::create();
-        std::string address = guessAddress();
-        contexts.push_back(std::move(context));
-        listeners.push_back(contexts.back()->listen(address));
-      }
-      auto mptChannel = tensorpipe::channel::mpt::create(
-          std::move(contexts), std::move(listeners));
-      context->registerChannel(20 /* high priority */, "mpt", mptChannel);
-    }
-  }
-}
 //////////////////////////// C APIs ////////////////////////////
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
     .set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); });
@@ -126,41 +62,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
     .set_body([](DGLArgs args, DGLRetValue* rv) {
       int64_t msg_queue_size = args[0];
-      std::string type = args[1];
-      int max_thread_count = args[2];
-      if (type == "tensorpipe") {
-        InitGlobalTpContext();
-        RPCContext::getInstance()->sender.reset(
-            new TPSender(RPCContext::getInstance()->ctx));
-      } else if (type == "socket") {
-        RPCContext::getInstance()->sender.reset(
-            new network::SocketSender(msg_queue_size, max_thread_count));
-      } else {
-        LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
-      }
-      LOG(INFO) << "Sender with NetType~"
-                << RPCContext::getInstance()->sender->NetType()
-                << " is created.";
+      int max_thread_count = args[1];
+      RPCContext::getInstance()->sender.reset(
+          new network::SocketSender(msg_queue_size, max_thread_count));
     });
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
     .set_body([](DGLArgs args, DGLRetValue* rv) {
       int64_t msg_queue_size = args[0];
-      std::string type = args[1];
-      int max_thread_count = args[2];
-      if (type == "tensorpipe") {
-        InitGlobalTpContext();
-        RPCContext::getInstance()->receiver.reset(
-            new TPReceiver(RPCContext::getInstance()->ctx));
-      } else if (type == "socket") {
-        RPCContext::getInstance()->receiver.reset(
-            new network::SocketReceiver(msg_queue_size, max_thread_count));
-      } else {
-        LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
-      }
-      LOG(INFO) << "Receiver with NetType~"
-                << RPCContext::getInstance()->receiver->NetType()
-                << " is created.";
+      int max_thread_count = args[1];
+      RPCContext::getInstance()->receiver.reset(
+          new network::SocketReceiver(msg_queue_size, max_thread_count));
     });
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
...
@@ -24,7 +24,6 @@
 #include "./server_state.h"
 #include "net_type.h"
 #include "network/socket_communicator.h"
-#include "tensorpipe/tp_communicator.h"
 namespace dgl {
 namespace rpc {
@@ -89,11 +88,6 @@ struct RPCContext {
   */
  std::shared_ptr<RPCReceiver> receiver;
- /**
-  * @brief Tensorpipe global context
-  */
- std::shared_ptr<tensorpipe::Context> ctx;
  /**
   * @brief Server state data.
   *
@@ -131,7 +125,6 @@ struct RPCContext {
    t->num_servers_per_machine = 0;
    t->sender.reset();
    t->receiver.reset();
-   t->ctx.reset();
    t->server_state.reset();
    t->group_id = -1;
    t->curr_client_id = -1;
...
# Introduction to tensorpipe
## Process of setting up communication
```cpp
context = std::make_shared<tensorpipe::Context>();
// For Receiver
// Create listener to accept join request
listener = context->listen({addr});
// Accept join request and generate pipe
std::promise<std::shared_ptr<Pipe>> pipeProm;
listener->accept([&](const Error& error, std::shared_ptr<Pipe> pipe) {
if (error) {
LOG(WARNING) << error.what();
}
pipeProm.set_value(std::move(pipe));
});
std::shared_ptr<Pipe> pipe = pipeProm.get_future().get();
// For Sender
pipe = context->connect(addr);
// Note that the pipe may not actually be usable at this point.
// For example, if no listener is listening on the address, no error is raised
// here; the error only shows up at the write/read operation. Thus we need to
// check the connection manually, as below.
std::promise<bool> done;
tensorpipe::Message tpmsg;
tpmsg.metadata = "dglconnect";
pipe->write(tpmsg, [&done](const tensorpipe::Error& error) {
if (error) {
done.set_value(false);
} else {
done.set_value(true);
}
});
// (In the original code this check runs inside a retry loop.)
if (done.get_future().get()) {
  break;  // connected successfully
} else {
  sleep(5);
  LOG(INFO) << "Cannot connect to remote server. Wait to retry";
}
```
## Read and Write
Message structure: https://github.com/pytorch/tensorpipe/blob/master/tensorpipe/core/message.h
There are three concepts: Message, Descriptor, and Allocation.
Message is the core struct for communication. A Message contains three major fields: metadata (a string), payloads (CPU memory buffers), and tensors (CPU/GPU memory buffers, each carrying its device as an attribute).
Descriptor and Allocation are used on the read side. A typical read operation looks as follows:
```cpp
pipe->readDescriptor(
    [pipe](const Error& error, Descriptor descriptor) {
      // The Descriptor contains the metadata of the message, the size of each
      // payload, and the device of each tensor -- everything except the actual
      // buffers. The reader allocates memory accordingly and hands it back via
      // an Allocation object.
      Allocation allocation;
      allocation.tensors.resize(descriptor.tensors.size());
      for (size_t i = 0; i < descriptor.tensors.size(); ++i) {
        tensorpipe::CpuBuffer cpu_buffer;
        cpu_buffer.ptr = new char[descriptor.tensors[i].length];
        allocation.tensors[i].buffer = cpu_buffer;
      }
      // Then call pipe->read() to receive the actual buffers into the allocation.
      pipe->read(allocation, [](const Error& error) {});
    });
```
Sending a message is much simpler:
```cpp
// Resource cleanup should be handled in the callback
pipe->write(message, callback_fn);
```
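A slightly fuller sketch (hypothetical; the buffer size is illustrative) of sending a message that carries one CPU tensor and releasing the buffer in the completion callback:
```cpp
tensorpipe::Message message;
message.metadata = "example";
// Attach one CPU tensor; the buffer must stay valid until the callback fires.
size_t size = 1024;  // illustrative
char* data = new char[size];
message.tensors.resize(1);
tensorpipe::CpuBuffer cpu_buffer;
cpu_buffer.ptr = data;
message.tensors[0].buffer = cpu_buffer;
message.tensors[0].length = size;
pipe->write(message, [data](const tensorpipe::Error& error) {
  if (error) {
    LOG(WARNING) << error.what();
  }
  delete[] data;  // resource cleanup happens in the callback
});
```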
## Register the underlying communication channel
There are two concepts: transport and channel.
A transport is the basic component for communication, such as sockets, and only supports CPU buffers.
A channel is a higher-level abstraction over transports; it can support GPU buffers or use multiple transports to accelerate communication.
Tensorpipe tries to set up channels based on their registered priority.
```cpp
// Register transport
auto context = std::make_shared<tensorpipe::Context>();
// uv is short for libuv, using epoll with sockets to communicate
auto transportContext = tensorpipe::transport::uv::create();
context->registerTransport(0 /* priority */, "tcp", transportContext);
// basic channel just use the bare transport to communicate
auto basicChannel = tensorpipe::channel::basic::create();
context->registerChannel(0, "basic", basicChannel);
// Below is the mpt (multiplexed transport) channel, which can use multiple uv transports to increase throughput
std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts = {
tensorpipe::transport::uv::create(), tensorpipe::transport::uv::create(),
tensorpipe::transport::uv::create()};
std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners = {
contexts[0]->listen("127.0.0.1"), contexts[1]->listen("127.0.0.1"),
contexts[2]->listen("127.0.0.1")};
auto mptChannel = tensorpipe::channel::mpt::create(
std::move(contexts), std::move(listeners));
context->registerChannel(10, "mpt", mptChannel);
```
There are more channels supported by tensorpipe, such as CUDA IPC (for CUDA communication on the same machine), CMA (using shared memory on the same machine), CUDA GDR (using InfiniBand with CUDA GPUDirect for GPU buffers), and CUDA Basic (using a socket plus a separate thread to copy buffers to/from CUDA memory).
Quote from tensorpipe:
> Backends come in two flavors:
>
> Transports are the connections used by the pipes to transfer control messages, and the (smallish) core payloads. They are meant to be lightweight and low-latency. The most basic transport is a simple TCP one, which should work in all scenarios. A more optimized one, for example, is based on a ring buffer allocated in shared memory, which two processes on the same machine can use to communicate by performing just a memory copy, without passing through the kernel.
>
> Channels are where the heavy lifting takes place, as they take care of copying the (larger) tensor data. High bandwidths are a requirement. Examples include multiplexing chunks of data across multiple TCP sockets and processes, so to saturate the NIC's bandwidth. Or using a CUDA memcpy call to transfer memory from one GPU to another using NVLink.
/**
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#ifndef DGL_RPC_TENSORPIPE_QUEUE_H_
#define DGL_RPC_TENSORPIPE_QUEUE_H_
#include <dmlc/logging.h>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <utility>
namespace dgl {
namespace rpc {
template <typename T>
class Queue {
public:
// The capacity is currently unused.
explicit Queue(int capacity = 1) : capacity_(capacity) {}
void push(T t) {
std::unique_lock<std::mutex> lock(mutex_);
// while (items_.size() >= capacity_) {
// cv_.wait(lock);
// }
items_.push_back(std::move(t));
cv_.notify_all();
}
bool pop(T *msg, int timeout) {
std::unique_lock<std::mutex> lock(mutex_);
if (timeout == 0) {
DLOG(WARNING) << "Will wait infinitely until message is popped...";
cv_.wait(lock, [this] { return items_.size() > 0; });
} else {
if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout), [this] {
return items_.size() > 0;
})) {
DLOG(WARNING) << "Times out for popping message after " << timeout
<< " milliseconds.";
return false;
}
}
*msg = std::move(items_.front());
items_.pop_front();
cv_.notify_all();
return true;
}
private:
std::mutex mutex_;
std::condition_variable cv_;
const int capacity_;
std::deque<T> items_;
};
} // namespace rpc
} // namespace dgl
#endif // DGL_RPC_TENSORPIPE_QUEUE_H_
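A hypothetical usage sketch for the Queue above (not part of the original file): it is an unbounded, thread-safe FIFO, and `pop()` with `timeout == 0` blocks until an item arrives, otherwise it waits at most `timeout` milliseconds.
```cpp
#include "queue.h"  // the header shown above; path is illustrative

int main() {
  dgl::rpc::Queue<int> q;
  q.push(42);
  int value = 0;
  bool ok = q.pop(&value, /*timeout=*/1000);  // waits up to 1000 ms
  return (ok && value == 42) ? 0 : 1;
}
```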
/**
* Copyright (c) 2019 by Contributors
* @file tp_communicator.cc
* @brief Tensorpipe Communicator for DGL distributed training.
*/
#include "tp_communicator.h"
#include <time.h>
#include <unistd.h>
#include <future>
#include <memory>
#include <utility>
#include "../rpc.h"
namespace dgl {
namespace rpc {
using namespace tensorpipe;
bool TPSender::ConnectReceiver(const std::string &addr, int recv_id) {
if (pipes_.find(recv_id) != pipes_.end()) {
LOG(WARNING) << "Duplicate recv_id[" << recv_id << "]. Ignoring...";
return true;
}
std::shared_ptr<Pipe> pipe;
pipe = context->connect(addr);
auto done = std::make_shared<std::promise<bool>>();
tensorpipe::Message tpmsg;
tpmsg.metadata = "dglconnect";
pipe->write(tpmsg, [done](const tensorpipe::Error &error) {
done->set_value(!error);
});
if (!done->get_future().get()) {
DLOG(WARNING) << "Failed to connect to receiver[" << addr << "].";
return false;
}
pipes_[recv_id] = pipe;
return true;
}
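// Wire format used by TPSender::Send and TPReceiver::ReceiveFromPipe: the
// RPCMessage is serialized with StreamWithBuffer into tp_msg.metadata, a
// trailing int32 holding the number of non-empty NDArrays is appended to the
// metadata, and the NDArray buffers themselves travel as tensorpipe CPU
// tensors (zero-copy on the sender side).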
void TPSender::Send(const RPCMessage &msg, int recv_id) {
auto pipe = pipes_[recv_id];
tensorpipe::Message tp_msg;
std::string *zerocopy_blob_ptr = &tp_msg.metadata;
StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true);
zc_write_strm.Write(msg);
int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
zerocopy_blob_ptr->append(
reinterpret_cast<char *>(&nonempty_ndarray_count), sizeof(int32_t));
tp_msg.tensors.resize(nonempty_ndarray_count);
// Hold the NDArrays to ensure they stay valid until the write operation completes
auto ndarray_holder = std::make_shared<std::vector<NDArray>>();
ndarray_holder->resize(nonempty_ndarray_count);
auto &buffer_list = zc_write_strm.buffer_list();
for (size_t i = 0; i < buffer_list.size(); i++) {
auto &ptr = buffer_list[i];
(*ndarray_holder.get())[i] = ptr.tensor;
tensorpipe::CpuBuffer cpu_buffer;
cpu_buffer.ptr = ptr.data;
tp_msg.tensors[i].buffer = cpu_buffer;
tp_msg.tensors[i].length = ptr.size;
if (ptr.size == 0) {
LOG(FATAL) << "Cannot send a empty NDArray.";
}
}
// Let's write blockingly in case of congestion in underlying transports.
auto done = std::make_shared<std::promise<void>>();
pipe->write(
tp_msg, [ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
if (error) {
LOG(FATAL) << "Failed to send message to " << recv_id
<< ". Details: " << error.what();
}
done->set_value();
});
done->get_future().wait();
}
void TPSender::Finalize() {
for (auto &&p : pipes_) {
if (p.second) {
p.second->close();
}
}
pipes_.clear();
}
void TPReceiver::Finalize() {
if (listener_) {
listener_->close();
}
for (auto &&p : pipes_) {
if (p.second) {
p.second->close();
}
}
pipes_.clear();
}
bool TPReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
if (listener_) {
LOG(WARNING) << "TPReceiver::Wait() has been called already. Ignoring...";
return true;
}
LOG(INFO) << "TPReceiver starts to wait on [" << addr << "].";
listener_ = context->listen({addr});
listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
OnAccepted(error, pipe);
});
while (blocking && (num_sender != num_connected_)) {
}
return true;
}
void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
if (error) {
if (error.isOfType<ListenerClosedError>()) {
// Expected.
} else {
LOG(WARNING) << "Unexpected error when accepting incoming pipe: "
<< error.what();
}
return;
}
// Accept the next connection request
listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
OnAccepted(error, pipe);
});
// read the handshake message: "dglconnect"
pipe->readDescriptor([pipe, this](const Error &error, Descriptor descriptor) {
if (error) {
LOG(ERROR) << "Unexpected error when reading from accepted pipe: "
<< error.what();
return;
}
Allocation allocation;
pipe->read(allocation, [](const Error &error) {});
CHECK(descriptor.metadata == "dglconnect") << "Invalid connect message.";
pipes_[num_connected_] = pipe;
ReceiveFromPipe(pipe, queue_);
++num_connected_;
});
}
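// Mirrors the wire format produced by TPSender::Send: read the descriptor,
// allocate one CPU buffer per tensor, then deserialize the metadata (minus the
// trailing int32 tensor count) together with the received buffers back into an
// RPCMessage and push it onto the queue.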
void TPReceiver::ReceiveFromPipe(
std::shared_ptr<Pipe> pipe, std::shared_ptr<RPCMessageQueue> queue) {
pipe->readDescriptor([pipe, queue = std::move(queue)](
const Error &error, Descriptor descriptor) {
if (error) {
// Error may happen when the pipe is closed
return;
}
Allocation allocation;
CHECK_EQ(descriptor.payloads.size(), 0) << "Invalid DGL RPC Message";
int tensorsize = descriptor.tensors.size();
if (tensorsize > 0) {
allocation.tensors.resize(tensorsize);
for (size_t i = 0; i < descriptor.tensors.size(); i++) {
tensorpipe::CpuBuffer cpu_buffer;
cpu_buffer.ptr = new char[descriptor.tensors[i].length];
allocation.tensors[i].buffer = cpu_buffer;
}
}
pipe->read(
allocation, [allocation, descriptor = std::move(descriptor),
queue = std::move(queue), pipe](const Error &error) {
if (error) {
// Because a read event is always posted to the epoll loop, an error is
// raised here when the pipe is closed; that error is expected. Other
// errors are not expected, but we cannot distinguish them from the
// expected one yet, so all errors are skipped here.
return;
}
char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
std::vector<void *> buffer_list(descriptor.tensors.size());
for (size_t i = 0; i < descriptor.tensors.size(); i++) {
buffer_list[i] =
allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
}
StreamWithBuffer zc_read_strm(
meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
buffer_list);
RPCMessage msg;
zc_read_strm.Read(&msg);
queue->push(msg);
TPReceiver::ReceiveFromPipe(pipe, queue);
});
});
}
RPCStatus TPReceiver::Recv(RPCMessage *msg, int timeout) {
return queue_->pop(msg, timeout) ? kRPCSuccess : kRPCTimeOut;
}
} // namespace rpc
} // namespace dgl
/**
* Copyright (c) 2019 by Contributors
* @file tp_communicator.h
* @brief Tensorpipe Communicator for DGL distributed training.
*/
#ifndef DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_
#define DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_
#include <dmlc/logging.h>
#include <tensorpipe/tensorpipe.h>
#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include "../net_type.h"
#include "./queue.h"
namespace dgl {
namespace rpc {
typedef Queue<RPCMessage> RPCMessageQueue;
/**
* @brief TPSender for DGL distributed training.
*
* TPSender is the communicator implemented on top of Tensorpipe.
*/
class TPSender : public RPCSender {
public:
/**
* @brief Sender constructor
* @param ctx the shared Tensorpipe context
*/
explicit TPSender(std::shared_ptr<tensorpipe::Context> ctx) {
CHECK(ctx) << "Context is not initialized";
this->context = ctx;
}
/**
* @brief Sender destructor
*/
~TPSender() { Finalize(); }
/**
* @brief Connect to a receiver.
*
* When there are multiple receivers to be connected, the application calls
* `ConnectReceiver` for each of them and then calls `ConnectReceiverFinalize`
* to verify that all the connections were successfully established (or to
* learn that some of them failed).
*
* @param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* @param recv_id receiver's ID
* @return True for success and False for fail
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
bool ConnectReceiver(const std::string& addr, int recv_id) override;
/**
* @brief Send RPCMessage to specified Receiver.
* @param msg data message
* @param recv_id receiver's ID
*/
void Send(const RPCMessage& msg, int recv_id) override;
/**
* @brief Finalize TPSender
*/
void Finalize() override;
/**
* @brief Communicator type: 'tensorpipe'
*/
const std::string& NetType() const override {
static const std::string net_type = "tensorpipe";
return net_type;
}
private:
/**
* @brief global context of tensorpipe
*/
std::shared_ptr<tensorpipe::Context> context;
/**
* @brief pipe for each connection of receiver
*/
std::unordered_map<int /* receiver ID */, std::shared_ptr<tensorpipe::Pipe>>
pipes_;
/**
* @brief receivers' listening address
*/
std::unordered_map<int /* receiver ID */, std::string> receiver_addrs_;
};
/**
* @brief TPReceiver for DGL distributed training.
*
* TPReceiver is the communicator implemented on top of Tensorpipe.
*/
class TPReceiver : public RPCReceiver {
public:
/**
* @brief Receiver constructor
* @param ctx the shared Tensorpipe context
*/
explicit TPReceiver(std::shared_ptr<tensorpipe::Context> ctx) {
CHECK(ctx) << "Context is not initialized";
this->context = ctx;
queue_ = std::make_shared<RPCMessageQueue>();
}
/**
* @brief Receiver destructor
*/
~TPReceiver() { Finalize(); }
/**
* @brief Wait for all the Senders to connect
* @param addr Networking address, e.g., 'tcp://127.0.0.1:50051'
* @param num_sender total number of Senders
* @param blocking whether to wait blockingly
* @return True for success and False for fail
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
bool Wait(
const std::string& addr, int num_sender, bool blocking = true) override;
/**
* @brief Recv RPCMessage from Sender. Actually removing data from queue.
* @param msg pointer of RPCmessage
* @param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* @return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
RPCStatus Recv(RPCMessage* msg, int timeout) override;
/**
* @brief Finalize TPReceiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
void Finalize() override;
/**
* @brief Communicator type: 'tensorpipe'
*/
const std::string& NetType() const override {
static const std::string net_type = "tensorpipe";
return net_type;
}
/**
* @brief Issue a receive request on pipe, and push the result into queue
*/
static void ReceiveFromPipe(
std::shared_ptr<tensorpipe::Pipe> pipe,
std::shared_ptr<RPCMessageQueue> queue);
private:
/**
* @brief Callback for new connection is accepted.
*/
void OnAccepted(const tensorpipe::Error&, std::shared_ptr<tensorpipe::Pipe>);
private:
/**
* @brief number of senders
*/
int num_sender_;
/**
* @brief listener to build pipe
*/
std::shared_ptr<tensorpipe::Listener> listener;
/**
* @brief global context of tensorpipe
*/
std::shared_ptr<tensorpipe::Context> context;
/**
* @brief pipe for each client connections
*/
std::unordered_map<
int /* Sender (virtual) ID */, std::shared_ptr<tensorpipe::Pipe>>
pipes_;
/**
* @brief RPCMessage queue
*/
std::shared_ptr<RPCMessageQueue> queue_;
/**
* @brief number of accepted connections
*/
std::atomic<int32_t> num_connected_{0};
/**
* @brief listener
*/
std::shared_ptr<tensorpipe::Listener> listener_{nullptr};
};
} // namespace rpc
} // namespace dgl
#endif // DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_
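For reference, a minimal wiring sketch (hypothetical, not part of the removed sources) showing how TPSender and TPReceiver were used together with a shared Tensorpipe context; the address, IDs, and timeout are illustrative:
```cpp
auto ctx = std::make_shared<tensorpipe::Context>();
ctx->registerTransport(0 /* priority */, "tcp", tensorpipe::transport::uv::create());
ctx->registerChannel(0 /* priority */, "basic", tensorpipe::channel::basic::create());

dgl::rpc::TPReceiver receiver(ctx);
// Non-blocking: the listener accepts connections asynchronously.
receiver.Wait("tcp://127.0.0.1:50051", /*num_sender=*/1, /*blocking=*/false);

dgl::rpc::TPSender sender(ctx);
if (sender.ConnectReceiver("tcp://127.0.0.1:50051", /*recv_id=*/0)) {
  dgl::rpc::RPCMessage msg;  // fill in the RPC fields as needed
  sender.Send(msg, /*recv_id=*/0);
  dgl::rpc::RPCMessage received;
  receiver.Recv(&received, /*timeout=*/1000);  // milliseconds; 0 waits forever
}
sender.Finalize();
receiver.Finalize();
```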