#ifndef LIGHTGBM_NETWORK_SOCKET_WRAPPER_HPP_
#define LIGHTGBM_NETWORK_SOCKET_WRAPPER_HPP_
#ifdef USE_SOCKET

// NOTE(review): include targets reconstructed from the names this header
// uses (socket API, getifaddrs, inet_ntop/inet_pton, Log, std containers);
// confirm against the build.
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#include <iphlpapi.h>
#else
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>
#endif

#include <LightGBM/utils/log.h>

#include <cstdlib>
#include <cstring>
#include <string>
#include <unordered_set>

#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#pragma comment(lib, "IPHLPAPI.lib")
#endif

namespace LightGBM {

#ifndef _WIN32
// Give POSIX the same socket vocabulary that Winsock provides.
typedef int SOCKET;
const int INVALID_SOCKET = -1;
#define SOCKET_ERROR (-1)
#endif

#if defined(_WIN32)
// Process-heap allocation helpers (used for the GetAdaptersInfo buffer).
#define MALLOC(x) HeapAlloc(GetProcessHeap(), 0, (x))
#define FREE(x) HeapFree(GetProcessHeap(), 0, (x))
#endif

namespace SocketConfig {
const int kSocketBufferSize = 10 * 1024 * 1024;
const int kMaxReceiveSize = 2 * 1024 * 1024;
const bool kNoDelay = true;
}  // namespace SocketConfig

/*!
 * \brief Thin wrapper around an IPv4 TCP socket descriptor.
 *
 * Copies share the same underlying descriptor, and the destructor does NOT
 * close it. Call Close() explicitly (on exactly one copy) to release it.
 */
class TcpSocket {
 public:
  /*! \brief Create a new TCP socket and apply ConfigSocket(); fatal on failure. */
  TcpSocket() {
    sockfd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (sockfd_ == INVALID_SOCKET) {
      Log::Fatal("Socket construct error\n");
      return;
    }
    ConfigSocket();
  }

  /*! \brief Wrap an existing descriptor (e.g. from accept()); fatal if invalid. */
  explicit TcpSocket(SOCKET socket) {
    sockfd_ = socket;
    if (sockfd_ == INVALID_SOCKET) {
      Log::Fatal("Passed socket error\n");
      return;
    }
    ConfigSocket();
  }

  /*! \brief Shallow copy: shares the descriptor and re-applies the socket options. */
  TcpSocket(const TcpSocket &object) {
    sockfd_ = object.sockfd_;
    ConfigSocket();
  }

  /*! \brief Intentionally does not close the descriptor (copies share it). */
  ~TcpSocket() {
  }

  /*!
   * \brief Set the receive timeout (SO_RCVTIMEO).
   * \param timeout Timeout in milliseconds (Winsock's unit for this option).
   */
  inline void SetTimeout(int timeout) {
#if defined(_WIN32)
    setsockopt(sockfd_, SOL_SOCKET, SO_RCVTIMEO,
               reinterpret_cast<const char*>(&timeout), sizeof(timeout));
#else
    // POSIX takes a struct timeval here; a bare int is rejected (EINVAL) on
    // Linux, which previously made this call a silent no-op.
    struct timeval tv;
    tv.tv_sec = timeout / 1000;
    tv.tv_usec = (timeout % 1000) * 1000;
    setsockopt(sockfd_, SOL_SOCKET, SO_RCVTIMEO,
               reinterpret_cast<const char*>(&tv), sizeof(tv));
#endif
  }

  /*!
   * \brief Apply buffer sizes and TCP_NODELAY from SocketConfig.
   *        Best effort: setsockopt failures are not reported.
   */
  inline void ConfigSocket() {
    if (sockfd_ == INVALID_SOCKET) {
      return;
    }
    setsockopt(sockfd_, SOL_SOCKET, SO_RCVBUF,
               reinterpret_cast<const char*>(&SocketConfig::kSocketBufferSize),
               sizeof(SocketConfig::kSocketBufferSize));
    setsockopt(sockfd_, SOL_SOCKET, SO_SNDBUF,
               reinterpret_cast<const char*>(&SocketConfig::kSocketBufferSize),
               sizeof(SocketConfig::kSocketBufferSize));
    // TCP_NODELAY expects an int-sized flag; a 1-byte bool is rejected
    // (EINVAL) on Linux, so convert before passing it.
    const int no_delay = SocketConfig::kNoDelay ? 1 : 0;
    setsockopt(sockfd_, IPPROTO_TCP, TCP_NODELAY,
               reinterpret_cast<const char*>(&no_delay), sizeof(no_delay));
  }

  /*! \brief Initialize the platform socket library (Winsock 2.2 on Windows; no-op elsewhere). */
  inline static void Startup() {
#if defined(_WIN32)
    WSADATA wsa_data;
    // WSAStartup returns 0 on success and a nonzero error code on failure.
    if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0) {
      Log::Fatal("Socket error: WSAStart up error\n");
    }
    if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
      WSACleanup();
      Log::Fatal("Socket error: Winsock.dll version error\n");
    }
#endif
  }

  /*! \brief Release the platform socket library (Windows only). */
  inline static void Finalize() {
#if defined(_WIN32)
    WSACleanup();
#endif
  }

  /*! \brief Last socket error code: WSAGetLastError() on Windows, errno elsewhere. */
  inline static int GetLastError() {
#if defined(_WIN32)
    return WSAGetLastError();
#else
    return errno;
#endif
  }

#if defined(_WIN32)
  /*! \brief Collect every IPv4 address string attached to the local adapters. */
  inline static std::unordered_set<std::string> GetLocalIpList() {
    std::unordered_set<std::string> ip_list;
    char buffer[512];
    // get hostName
    if (gethostname(buffer, sizeof(buffer)) == SOCKET_ERROR) {
      Log::Fatal("Error code: %d, when getting local host name.\n", WSAGetLastError());
    }
    // push local ip
    DWORD dwRetVal = 0;
    ULONG ulOutBufLen = sizeof(IP_ADAPTER_INFO);
    PIP_ADAPTER_INFO pAdapterInfo = static_cast<PIP_ADAPTER_INFO>(MALLOC(sizeof(IP_ADAPTER_INFO)));
    if (pAdapterInfo == NULL) {
      Log::Fatal("GetAdaptersinfo error: allocating memory \n");
    }
    // Make an initial call to GetAdaptersInfo to get
    // the necessary size into the ulOutBufLen variable
    if (GetAdaptersInfo(pAdapterInfo, &ulOutBufLen) == ERROR_BUFFER_OVERFLOW) {
      FREE(pAdapterInfo);
      pAdapterInfo = static_cast<PIP_ADAPTER_INFO>(MALLOC(ulOutBufLen));
      if (pAdapterInfo == NULL) {
        Log::Fatal("GetAdaptersinfo error: allocating memory \n");
      }
    }
    if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) {
      for (PIP_ADAPTER_INFO pAdapter = pAdapterInfo; pAdapter != NULL; pAdapter = pAdapter->Next) {
        // IpAddressList is a linked list; an adapter may carry several addresses.
        for (const IP_ADDR_STRING* addr = &pAdapter->IpAddressList; addr != NULL; addr = addr->Next) {
          ip_list.insert(addr->IpAddress.String);
        }
      }
    } else {
      Log::Error("GetAdaptersinfo error: code %d \n", static_cast<int>(dwRetVal));
    }
    if (pAdapterInfo) {
      FREE(pAdapterInfo);
    }
    return ip_list;
  }
#else
  /*!
   * \brief Collect the IPv4 address strings of all local interfaces.
   *        See http://stackoverflow.com/questions/212528/get-the-ip-address-of-the-machine
   */
  inline static std::unordered_set<std::string> GetLocalIpList() {
    std::unordered_set<std::string> ip_list;
    struct ifaddrs* ifAddrStruct = nullptr;
    if (getifaddrs(&ifAddrStruct) != 0) {
      Log::Error("getifaddrs error: code %d \n", errno);
      return ip_list;
    }
    for (struct ifaddrs* ifa = ifAddrStruct; ifa != nullptr; ifa = ifa->ifa_next) {
      // Skip interfaces without an address and non-IPv4 families.
      if (ifa->ifa_addr == nullptr || ifa->ifa_addr->sa_family != AF_INET) {
        continue;
      }
      const void* tmpAddrPtr = &reinterpret_cast<struct sockaddr_in*>(ifa->ifa_addr)->sin_addr;
      char addressBuffer[INET_ADDRSTRLEN];
      if (inet_ntop(AF_INET, tmpAddrPtr, addressBuffer, INET_ADDRSTRLEN) != nullptr) {
        ip_list.insert(std::string(addressBuffer));
      }
    }
    if (ifAddrStruct != nullptr) {
      freeifaddrs(ifAddrStruct);
    }
    return ip_list;
  }
#endif

  /*!
   * \brief Build an IPv4 sockaddr_in for a dotted-quad address and port.
   * NOTE(review): inet_pton's result is not checked. A string that is not an
   * IPv4 literal leaves sin_addr as 0.0.0.0.
   */
  inline static sockaddr_in GetAddress(const char* url, int port) {
    sockaddr_in addr = sockaddr_in();
    std::memset(&addr, 0, sizeof(sockaddr_in));
    inet_pton(AF_INET, url, &addr.sin_addr);
    addr.sin_family = AF_INET;
    addr.sin_port = htons(static_cast<unsigned short>(port));
    return addr;
  }

  /*! \brief Bind to 0.0.0.0:port; returns true on success. */
  inline bool Bind(int port) {
    sockaddr_in local_addr = GetAddress("0.0.0.0", port);
    return bind(sockfd_, reinterpret_cast<const sockaddr*>(&local_addr),
                sizeof(sockaddr_in)) == 0;
  }

  /*! \brief Connect to url:port; returns true on success. */
  inline bool Connect(const char *url, int port) {
    sockaddr_in server_addr = GetAddress(url, port);
    return connect(sockfd_, reinterpret_cast<const sockaddr*>(&server_addr),
                   sizeof(sockaddr_in)) == 0;
  }

  /*! \brief Start listening; the return value of listen() is not checked. */
  inline void Listen(int backlog = 128) {
    listen(sockfd_, backlog);
  }

  /*! \brief Accept one incoming connection; fatal on error. */
  inline TcpSocket Accept() {
    SOCKET newfd = accept(sockfd_, NULL, NULL);
    if (newfd == INVALID_SOCKET) {
      Log::Fatal("Socket accept error, code: %d", GetLastError());
    }
    return TcpSocket(newfd);
  }

  /*! \brief One send() call; returns bytes sent (may be < len); fatal on error. */
  inline int Send(const char *buf_, int len, int flag = 0) {
    int cur_cnt = send(sockfd_, buf_, len, flag);
    if (cur_cnt == SOCKET_ERROR) {
      Log::Fatal("Socket send error, code: %d", GetLastError());
    }
    return cur_cnt;
  }

  /*! \brief One recv() call; returns bytes received (0 on orderly shutdown); fatal on error. */
  inline int Recv(char *buf_, int len, int flags = 0) {
    int cur_cnt = recv(sockfd_, buf_, len, flags);
    if (cur_cnt == SOCKET_ERROR) {
      Log::Fatal("Socket recv error, code: %d", GetLastError());
    }
    return cur_cnt;
  }

  /*! \brief True once Close() has run on this object (or the descriptor is invalid). */
  inline bool IsClosed() {
    return sockfd_ == INVALID_SOCKET;
  }

  /*! \brief Close the descriptor if open and mark this object closed. */
  inline void Close() {
    if (!IsClosed()) {
#if defined(_WIN32)
      closesocket(sockfd_);
#else
      close(sockfd_);
#endif
      sockfd_ = INVALID_SOCKET;
    }
  }

 private:
  SOCKET sockfd_;
};

}  // namespace LightGBM
#endif  // USE_SOCKET
#endif  // LIGHTGBM_NETWORK_SOCKET_WRAPPER_HPP_