/*! * Copyright (c) 2019 by Contributors * \file communicator.cc * \brief SocketCommunicator for DGL distributed training. */ #include #include #include #include #include #include "socket_communicator.h" #include "../../c_api_common.h" #include "socket_pool.h" #ifdef _WIN32 #include #else // !_WIN32 #include #endif // _WIN32 namespace dgl { namespace network { /////////////////////////////////////// SocketSender /////////////////////////////////////////// bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) { if (recv_id < 0) { LOG(FATAL) << "recv_id cannot be a negative number."; } std::vector substring; std::vector ip_and_port; SplitStringUsing(addr, "//", &substring); // Check address format if (substring[0] != "tcp:" || substring.size() != 2) { LOG(FATAL) << "Incorrect address format:" << addr << " Please provide right address format, " << "e.g, 'tcp://127.0.0.1:50051'. "; } // Get IP and port SplitStringUsing(substring[1], ":", &ip_and_port); if (ip_and_port.size() != 2) { LOG(FATAL) << "Incorrect address format:" << addr << " Please provide right address format, " << "e.g, 'tcp://127.0.0.1:50051'. "; } IPAddr address; address.ip = ip_and_port[0]; address.port = std::stoi(ip_and_port[1]); receiver_addrs_[recv_id] = address; return true; } bool SocketSender::ConnectReceiverFinalize(const int max_try_times) { // Create N sockets for Receiver int receiver_count = static_cast(receiver_addrs_.size()); if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) { max_thread_count_ = receiver_count; } sockets_.resize(max_thread_count_); for (const auto& r : receiver_addrs_) { int receiver_id = r.first; int thread_id = receiver_id % max_thread_count_; sockets_[thread_id][receiver_id] = std::make_shared(); TCPSocket* client_socket = sockets_[thread_id][receiver_id].get(); bool bo = false; int try_count = 0; const char* ip = r.second.ip.c_str(); int port = r.second.port; while (bo == false && try_count < max_try_times) { if (client_socket->Connect(ip, port)) { bo = true; } else { if (try_count % 200 == 0 && try_count != 0) { // every 600 seconds show this message LOG(INFO) << "Trying to connect receiver: " << ip << ":" << port; } try_count++; std::this_thread::sleep_for(std::chrono::seconds(3)); } } if (bo == false) { return bo; } } for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) { msg_queue_.push_back(std::make_shared(queue_size_)); // Create a new thread for this socket connection threads_.push_back(std::make_shared( SendLoop, sockets_[thread_id], msg_queue_[thread_id])); } return true; } void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) { std::shared_ptr zerocopy_blob(new std::string()); StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true); zc_write_strm.Write(msg); int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size(); zerocopy_blob->append(reinterpret_cast(&nonempty_ndarray_count), sizeof(int32_t)); Message rpc_meta_msg; rpc_meta_msg.data = const_cast(zerocopy_blob->data()); rpc_meta_msg.size = zerocopy_blob->size(); rpc_meta_msg.deallocator = [zerocopy_blob](Message*) {}; CHECK_EQ(Send( rpc_meta_msg, recv_id), ADD_SUCCESS); // send real ndarray data for (auto ptr : zc_write_strm.buffer_list()) { Message ndarray_data_msg; ndarray_data_msg.data = reinterpret_cast(ptr.data); if (ptr.size == 0) { LOG(FATAL) << "Cannot send a empty NDArray."; } ndarray_data_msg.size = ptr.size; NDArray tensor = ptr.tensor; ndarray_data_msg.deallocator = [tensor](Message*) {}; CHECK_EQ(Send( ndarray_data_msg, recv_id), ADD_SUCCESS); } } STATUS SocketSender::Send(Message msg, int recv_id) { CHECK_NOTNULL(msg.data); CHECK_GT(msg.size, 0); CHECK_GE(recv_id, 0); msg.receiver_id = recv_id; // Add data message to message queue STATUS code = msg_queue_[recv_id % max_thread_count_]->Add(msg); return code; } void SocketSender::Finalize() { // Send a signal to tell the msg_queue to finish its job for (int i = 0; i < max_thread_count_; ++i) { // wait until queue is empty auto& mq = msg_queue_[i]; while (mq->Empty() == false) { std::this_thread::sleep_for(std::chrono::seconds(1)); } // All queues have only one producer, which is main thread, so // the producerID argument here should be zero. mq->SignalFinished(0); } // Block main thread until all socket-threads finish their jobs for (auto& thread : threads_) { thread->join(); } // Clear all sockets for (auto& group_sockets_ : sockets_) { for (auto &socket : group_sockets_) { socket.second->Close(); } } } void SendCore(Message msg, TCPSocket* socket) { // First send the size // If exit == true, we will send zero size to reciever int64_t sent_bytes = 0; while (static_cast(sent_bytes) < sizeof(int64_t)) { int64_t max_len = sizeof(int64_t) - sent_bytes; int64_t tmp = socket->Send( reinterpret_cast(&msg.size) + sent_bytes, max_len); CHECK_NE(tmp, -1); sent_bytes += tmp; } // Then send the data sent_bytes = 0; while (sent_bytes < msg.size) { int64_t max_len = msg.size - sent_bytes; int64_t tmp = socket->Send(msg.data+sent_bytes, max_len); CHECK_NE(tmp, -1); sent_bytes += tmp; } // delete msg if (msg.deallocator != nullptr) { msg.deallocator(&msg); } } void SocketSender::SendLoop( std::unordered_map> sockets, std::shared_ptr queue) { for (;;) { Message msg; STATUS code = queue->Remove(&msg); if (code == QUEUE_CLOSE) { msg.size = 0; // send an end-signal to receiver for (auto& socket : sockets) { SendCore(msg, socket.second.get()); } break; } SendCore(msg, sockets[msg.receiver_id].get()); } } /////////////////////////////////////// SocketReceiver /////////////////////////////////////////// bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking) { CHECK_GT(num_sender, 0); CHECK_EQ(blocking, true); std::vector substring; std::vector ip_and_port; SplitStringUsing(addr, "//", &substring); // Check address format if (substring[0] != "tcp:" || substring.size() != 2) { LOG(FATAL) << "Incorrect address format:" << addr << " Please provide right address format, " << "e.g, 'tcp://127.0.0.1:50051'. "; } // Get IP and port SplitStringUsing(substring[1], ":", &ip_and_port); if (ip_and_port.size() != 2) { LOG(FATAL) << "Incorrect address format:" << addr << " Please provide right address format, " << "e.g, 'tcp://127.0.0.1:50051'. "; } std::string ip = ip_and_port[0]; int port = stoi(ip_and_port[1]); // Initialize message queue for each connection num_sender_ = num_sender; #ifdef USE_EPOLL if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) { max_thread_count_ = num_sender_; } #else max_thread_count_ = num_sender_; #endif // Initialize socket and socket-thread server_socket_ = new TCPSocket(); // Bind socket if (server_socket_->Bind(ip.c_str(), port) == false) { LOG(FATAL) << "Cannot bind to " << ip << ":" << port; } // Listen if (server_socket_->Listen(kMaxConnection) == false) { LOG(FATAL) << "Cannot listen on " << ip << ":" << port; } // Accept all sender sockets std::string accept_ip; int accept_port; sockets_.resize(max_thread_count_); for (int i = 0; i < num_sender_; ++i) { int thread_id = i % max_thread_count_; auto socket = std::make_shared(); sockets_[thread_id][i] = socket; msg_queue_[i] = std::make_shared(queue_size_); if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) == false) { LOG(WARNING) << "Error on accept socket."; return false; } } mq_iter_ = msg_queue_.begin(); for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) { // create new thread for each socket threads_.push_back(std::make_shared( RecvLoop, sockets_[thread_id], msg_queue_, &queue_sem_)); } return true; } rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) { Message rpc_meta_msg; int send_id; auto status = Recv(&rpc_meta_msg, &send_id, timeout); if (status == QUEUE_EMPTY) { DLOG(WARNING) << "Timed out when trying to receive rpc meta data after " << timeout << " milliseconds."; return rpc::kRPCTimeOut; } CHECK_EQ(status, REMOVE_SUCCESS); char* count_ptr = rpc_meta_msg.data+rpc_meta_msg.size-sizeof(int32_t); int32_t nonempty_ndarray_count = *(reinterpret_cast(count_ptr)); // Recv real ndarray data std::vector buffer_list(nonempty_ndarray_count); for (int i = 0; i < nonempty_ndarray_count; ++i) { Message ndarray_data_msg; status = RecvFrom(&ndarray_data_msg, send_id, timeout); if (status == QUEUE_EMPTY) { // As we cannot handle this timeout for now, let's treat it as fatal // error. LOG(FATAL) << "Timed out when trying to receive rpc ndarray data after " << timeout << " milliseconds."; return rpc::kRPCTimeOut; } CHECK_EQ(status, REMOVE_SUCCESS); buffer_list[i] = ndarray_data_msg.data; } StreamWithBuffer zc_read_strm(rpc_meta_msg.data, rpc_meta_msg.size-sizeof(int32_t), buffer_list); zc_read_strm.Read(msg); rpc_meta_msg.deallocator(&rpc_meta_msg); return rpc::kRPCSuccess; } STATUS SocketReceiver::Recv(Message* msg, int* send_id, int timeout) { // queue_sem_ is a semaphore indicating how many elements in multiple // message queues. // When calling queue_sem_.Wait(), this Recv will be suspended until // queue_sem_ > 0 or specified timeout expires, decrease queue_sem_ by 1, // then start to fetch a message. if (!queue_sem_.TimedWait(timeout)) { return QUEUE_EMPTY; } for (;;) { for (; mq_iter_ != msg_queue_.end(); ++mq_iter_) { STATUS code = mq_iter_->second->Remove(msg, false); if (code == QUEUE_EMPTY) { continue; // jump to the next queue } else { *send_id = mq_iter_->first; ++mq_iter_; return code; } } mq_iter_ = msg_queue_.begin(); } LOG(ERROR) << "Failed to remove message from queue due to unexpected queue status."; return QUEUE_CLOSE; } STATUS SocketReceiver::RecvFrom(Message* msg, int send_id, int timeout) { // Get message from specified message queue if (!queue_sem_.TimedWait(timeout)) { return QUEUE_EMPTY; } STATUS code = msg_queue_[send_id]->Remove(msg); return code; } void SocketReceiver::Finalize() { // Send a signal to tell the message queue to finish its job for (auto& mq : msg_queue_) { // wait until queue is empty while (mq.second->Empty() == false) { std::this_thread::sleep_for(std::chrono::seconds(1)); } mq.second->SignalFinished(mq.first); } // Block main thread until all socket-threads finish their jobs for (auto& thread : threads_) { thread->join(); } // Clear all sockets for (auto& group_sockets : sockets_) { for (auto& socket : group_sockets) { socket.second->Close(); } } server_socket_->Close(); delete server_socket_; } int64_t RecvDataSize(TCPSocket* socket) { int64_t received_bytes = 0; int64_t data_size = 0; while (static_cast(received_bytes) < sizeof(int64_t)) { int64_t max_len = sizeof(int64_t) - received_bytes; int64_t tmp = socket->Receive( reinterpret_cast(&data_size) + received_bytes, max_len); if (tmp == -1) { if (received_bytes > 0) { // We want to finish reading full data_size continue; } return -1; } received_bytes += tmp; } return data_size; } void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size, int64_t *received_bytes) { while (*received_bytes < data_size) { int64_t max_len = data_size - *received_bytes; int64_t tmp = socket->Receive(buffer + *received_bytes, max_len); if (tmp == -1) { // Socket not ready, no more data to read return; } *received_bytes += tmp; } } void SocketReceiver::RecvLoop( std::unordered_map> sockets, std::unordered_map> queues, runtime::Semaphore *queue_sem) { std::unordered_map> recv_contexts; SocketPool socket_pool; for (auto& socket : sockets) { auto &sender_id = socket.first; socket_pool.AddSocket(socket.second, sender_id); recv_contexts[sender_id] = std::unique_ptr(new RecvContext()); } // Main loop to receive messages for (;;) { int sender_id; // Get active socket using epoll std::shared_ptr socket = socket_pool.GetActiveSocket(&sender_id); if (queues[sender_id]->EmptyAndNoMoreAdd()) { // This sender has already stopped if (socket_pool.RemoveSocket(socket) == 0) { return; } continue; } // Nonblocking socket might be interrupted at any point. So we need to // store the partially received data std::unique_ptr &ctx = recv_contexts[sender_id]; int64_t &data_size = ctx->data_size; int64_t &received_bytes = ctx->received_bytes; char*& buffer = ctx->buffer; if (data_size == -1) { // This is a new message, so receive the data size first data_size = RecvDataSize(socket.get()); if (data_size > 0) { try { buffer = new char[data_size]; } catch(const std::bad_alloc&) { LOG(FATAL) << "Cannot allocate enough memory for message, " << "(message size: " << data_size << ")"; } received_bytes = 0; } else if (data_size == 0) { // Received stop signal if (socket_pool.RemoveSocket(socket) == 0) { return; } } } RecvData(socket.get(), buffer, data_size, &received_bytes); if (received_bytes >= data_size) { // Full data received, create Message and push to queue Message msg; msg.data = buffer; msg.size = data_size; msg.deallocator = DefaultMessageDeleter; queues[sender_id]->Add(msg); // Reset recv context data_size = -1; // Signal queue semaphore queue_sem->Post(); } } } } // namespace network } // namespace dgl