/*! * Copyright (c) 2019 by Contributors * \file tp_communicator.h * \brief Tensorpipe Communicator for DGL distributed training. */ #ifndef DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_ #define DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_ #include #include #include #include #include #include #include #include #include #include "./queue.h" namespace dgl { namespace rpc { class RPCMessage; typedef Queue RPCMessageQueue; /*! * \brief TPSender for DGL distributed training. * * TPSender is the communicator implemented by tcp socket. */ class TPSender { public: /*! * \brief Sender constructor * \param queue_size size of message queue */ explicit TPSender(std::shared_ptr ctx) { CHECK(ctx) << "Context is not initialized"; this->context = ctx; } /*! * \brief Sender destructor */ ~TPSender() { Finalize(); } /*! * \brief Connect to receiver with address and ID * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091' * \param recv_id receiver's ID * \return True for success and False for fail * * ConnectReceiver() is not thread-safe and only one thread can invoke this API. */ bool ConnectReceiver(const std::string& addr, int recv_id); /*! * \brief Send RPCMessage to specified Receiver. * \param msg data message \param recv_id receiver's ID */ void Send(const RPCMessage& msg, int recv_id); /*! * \brief Finalize TPSender */ void Finalize(); /*! * \brief Communicator type: 'tp' */ inline std::string Type() const { return std::string("tp"); } private: /*! * \brief global context of tensorpipe */ std::shared_ptr context; /*! * \brief pipe for each connection of receiver */ std::unordered_map> pipes_; /*! * \brief receivers' listening address */ std::unordered_map receiver_addrs_; }; /*! * \brief TPReceiver for DGL distributed training. * * Tensorpipe Receiver is the communicator implemented by tcp socket. */ class TPReceiver { public: /*! * \brief Receiver constructor * \param queue_size size of message queue. */ explicit TPReceiver(std::shared_ptr ctx) { CHECK(ctx) << "Context is not initialized"; this->context = ctx; queue_ = std::make_shared(); } /*! * \brief Receiver destructor */ ~TPReceiver() { Finalize(); } /*! * \brief Wait for all the Senders to connect * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051' * \param num_sender total number of Senders * \param blocking whether to wait blockingly * \return True for success and False for fail * * Wait() is not thread-safe and only one thread can invoke this API. */ bool Wait(const std::string &addr, int num_sender, bool blocking = true); /*! * \brief Recv RPCMessage from Sender. Actually removing data from queue. * \param msg pointer of RPCmessage * \param send_id which sender current msg comes from * \return Status code * * (1) The Recv() API is blocking, which will not * return until getting data from message queue. * (2) The Recv() API is thread-safe. * (3) Memory allocated by communicator but will not own it after the function * returns. */ void Recv(RPCMessage* msg); /*! * \brief Finalize SocketReceiver * * Finalize() is not thread-safe and only one thread can invoke this API. */ void Finalize(); /*! * \brief Communicator type: 'tp' (tensorpipe) */ inline std::string Type() const { return std::string("tp"); } /*! * \brief Issue a receive request on pipe, and push the result into queue */ static void ReceiveFromPipe(std::shared_ptr pipe, std::shared_ptr queue); private: /*! * \brief Callback for new connection is accepted. */ void OnAccepted(const tensorpipe::Error&, std::shared_ptr); private: /*! * \brief number of sender */ int num_sender_; /*! * \brief listener to build pipe */ std::shared_ptr listener; /*! * \brief global context of tensorpipe */ std::shared_ptr context; /*! * \brief pipe for each client connections */ std::unordered_map> pipes_; /*! * \brief RPCMessage queue */ std::shared_ptr queue_; /*! * \brief number of accepted connections */ std::atomic num_connected_{0}; /*! * \brief listner */ std::shared_ptr listener_{nullptr}; }; } // namespace rpc } // namespace dgl #endif // DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_