/*! * Copyright (c) 2019 by Contributors * \file communicator.h * \brief SocketCommunicator for DGL distributed training. */ #ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_ #define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_ #include #include #include #include #include #include "../../runtime/semaphore_wrapper.h" #include "communicator.h" #include "msg_queue.h" #include "tcp_socket.h" #include "common.h" namespace dgl { namespace network { static constexpr int kMaxTryCount = 1024; // maximal connection: 1024 static constexpr int kTimeOut = 10 * 60; // 10 minutes (in seconds) for socket timeout static constexpr int kMaxConnection = 1024; // maximal connection: 1024 /*! * \breif Networking address */ struct IPAddr { std::string ip; int port; }; /*! * \brief SocketSender for DGL distributed training. * * SocketSender is the communicator implemented by tcp socket. */ class SocketSender : public Sender { public: /*! * \brief Sender constructor * \param queue_size size of message queue * \param max_thread_count size of thread pool. 0 for no limit */ SocketSender(int64_t queue_size, int max_thread_count) : Sender(queue_size, max_thread_count) {} /*! * \brief Add receiver's address and ID to the sender's namebook * \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0' * \param id receiver's ID * * AddReceiver() is not thread-safe and only one thread can invoke this API. */ void AddReceiver(const char* addr, int recv_id); /*! * \brief Connect with all the Receivers * \return True for success and False for fail * * Connect() is not thread-safe and only one thread can invoke this API. */ bool Connect(); /*! * \brief Send data to specified Receiver. Actually pushing message to message queue. * \param msg data message * \param recv_id receiver's ID * \return Status code * * (1) The send is non-blocking. There is no guarantee that the message has been * physically sent out when the function returns. * (2) The communicator will assume the responsibility of the given message. * (3) The API is multi-thread safe. * (4) Messages sent to the same receiver are guaranteed to be received in the same order. * There is no guarantee for messages sent to different receivers. */ STATUS Send(Message msg, int recv_id); /*! * \brief Finalize SocketSender * * Finalize() is not thread-safe and only one thread can invoke this API. */ void Finalize(); /*! * \brief Communicator type: 'socket' */ inline std::string Type() const { return std::string("socket"); } private: /*! * \brief socket for each connection of receiver */ std::vector>> sockets_; /*! * \brief receivers' address */ std::unordered_map receiver_addrs_; /*! * \brief message queue for each thread */ std::vector> msg_queue_; /*! * \brief Independent thread */ std::vector> threads_; /*! * \brief Send-loop for each thread * \param sockets TCPSockets for current thread * \param queue message_queue for current thread * * Note that, the SendLoop will finish its loop-job and exit thread * when the main thread invokes Signal() API on the message queue. */ static void SendLoop( std::unordered_map> sockets, std::shared_ptr queue); }; /*! * \brief SocketReceiver for DGL distributed training. * * SocketReceiver is the communicator implemented by tcp socket. */ class SocketReceiver : public Receiver { public: /*! * \brief Receiver constructor * \param queue_size size of message queue. * \param max_thread_count size of thread pool. 0 for no limit */ SocketReceiver(int64_t queue_size, int max_thread_count) : Receiver(queue_size, max_thread_count) {} /*! * \brief Wait for all the Senders to connect * \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0' * \param num_sender total number of Senders * \return True for success and False for fail * * Wait() is not thread-safe and only one thread can invoke this API. */ bool Wait(const char* addr, int num_sender); /*! * \brief Recv data from Sender. Actually removing data from msg_queue. * \param msg pointer of data message * \param send_id which sender current msg comes from * \return Status code * * (1) The Recv() API is blocking, which will not * return until getting data from message queue. * (2) The Recv() API is thread-safe. * (3) Memory allocated by communicator but will not own it after the function returns. */ STATUS Recv(Message* msg, int* send_id); /*! * \brief Recv data from a specified Sender. Actually removing data from msg_queue. * \param msg pointer of data message * \param send_id sender's ID * \return Status code * * (1) The RecvFrom() API is blocking, which will not * return until getting data from message queue. * (2) The RecvFrom() API is thread-safe. * (3) Memory allocated by communicator but will not own it after the function returns. */ STATUS RecvFrom(Message* msg, int send_id); /*! * \brief Finalize SocketReceiver * * Finalize() is not thread-safe and only one thread can invoke this API. */ void Finalize(); /*! * \brief Communicator type: 'socket' */ inline std::string Type() const { return std::string("socket"); } private: struct RecvContext { int64_t data_size = -1; int64_t received_bytes = 0; char *buffer = nullptr; }; /*! * \brief number of sender */ int num_sender_; /*! * \brief server socket for listening connections */ TCPSocket* server_socket_; /*! * \brief socket for each client connections */ std::vector>> sockets_; /*! * \brief Message queue for each socket connection */ std::unordered_map> msg_queue_; std::unordered_map>::iterator mq_iter_; /*! * \brief Independent thead */ std::vector> threads_; /*! * \brief queue_sem_ semphore to indicate number of messages in multiple * message queues to prevent busy wait of Recv */ runtime::Semaphore queue_sem_; /*! * \brief Recv-loop for each thread * \param sockets client sockets of current thread * \param queue message queues of current thread * * Note that, the RecvLoop will finish its loop-job and exit thread * when the main thread invokes Signal() API on the message queue. */ static void RecvLoop( std::unordered_map> sockets, std::unordered_map> queues, runtime::Semaphore *queue_sem); }; } // namespace network } // namespace dgl #endif // DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_