tcp_socket.cc 5.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/*!
 *  Copyright (c) 2019 by Contributors
 * \file tcp_socket.cc
 * \brief TCP socket for DGL distributed training.
 */
#include "tcp_socket.h"

#include <dmlc/logging.h>

#ifndef _WIN32
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#endif  // !_WIN32
18
#include <errno.h>
19
#include <string.h>
20
21
22
23
24
25
26
27
28
29
30

namespace dgl {
namespace network {

// Short aliases for the BSD socket address structures used throughout this
// file: SAI is the IPv4-specific form, SA the generic form expected by the
// socket system calls.
using SAI = struct sockaddr_in;
using SA = struct sockaddr;

/*!
 * \brief Create a new IPv4 stream (TCP) socket and store its descriptor.
 *
 * A creation failure is reported through LOG(FATAL). On non-Windows builds
 * SO_REUSEADDR is additionally requested; if that request fails only a
 * warning is logged and construction continues.
 */
TCPSocket::TCPSocket() {
  socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  if (socket_ < 0) {
    LOG(FATAL) << "Can't create new socket. Error: " << strerror(errno);
  }
#ifndef _WIN32
  // Let the same port be bound again right after this socket is closed.
  int reuse_addr = 1;
  const int status =
      setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr, sizeof(int));
  if (status < 0) {
    LOG(WARNING) << "cannot make the socket reusable. Error: "
                 << strerror(errno);
  }
#endif  // !_WIN32
}

44
/*! \brief Destructor; releases the descriptor by delegating to Close(). */
TCPSocket::~TCPSocket() {
  Close();
}
45

46
bool TCPSocket::Connect(const char *ip, int port) {
47
  SAI sa_server;
48
49
  sa_server.sin_family = AF_INET;
  sa_server.sin_port = htons(port);
50

mszarma's avatar
mszarma committed
51
52
53
  int retval = 0;
  do {  // retry if EINTR failure appears
    if (0 < inet_pton(AF_INET, ip, &sa_server.sin_addr) &&
54
55
56
        0 <= (retval = connect(
                  socket_, reinterpret_cast<SA *>(&sa_server),
                  sizeof(sa_server)))) {
mszarma's avatar
mszarma committed
57
58
59
60
      return true;
    }
  } while (retval == -1 && errno == EINTR);

61
62
63
  return false;
}

64
bool TCPSocket::Bind(const char *ip, int port) {
65
  SAI sa_server;
66
67
  sa_server.sin_family = AF_INET;
  sa_server.sin_port = htons(port);
68
69
70
71
72
73
74
75
76
77
  int ret = 0;
  ret = inet_pton(AF_INET, ip, &sa_server.sin_addr);
  if (ret == 0) {
    LOG(ERROR) << "Invalid IP: " << ip;
    return false;
  } else if (ret < 0) {
    LOG(ERROR) << "Failed to convert [" << ip
               << "] to binary form, error: " << strerror(errno);
    return false;
  }
mszarma's avatar
mszarma committed
78
  do {  // retry if EINTR failure appears
79
80
81
    if (0 <=
        (ret = bind(
             socket_, reinterpret_cast<SA *>(&sa_server), sizeof(sa_server)))) {
mszarma's avatar
mszarma committed
82
83
      return true;
    }
84
  } while (ret == -1 && errno == EINTR);
mszarma's avatar
mszarma committed
85

86
87
  LOG(ERROR) << "Failed bind on " << ip << ":" << port
             << " , error: " << strerror(errno);
88
89
90
91
  return false;
}

bool TCPSocket::Listen(int max_connection) {
mszarma's avatar
mszarma committed
92
93
94
95
96
97
98
  int retval;
  do {  // retry if EINTR failure appears
    if (0 <= (retval = listen(socket_, max_connection))) {
      return true;
    }
  } while (retval == -1 && errno == EINTR);

99
100
  LOG(ERROR) << "Failed listen on socket fd: " << socket_
             << " , error: " << strerror(errno);
101
102
103
  return false;
}

104
/*!
 * \brief Accept one incoming connection on this (listening) socket.
 * \param socket Receives the accepted descriptor. Any descriptor it already
 *        holds is closed first so that it is not leaked.
 * \param ip Output: textual IPv4 address of the connected peer.
 * \param port Output: port of the connected peer, in host byte order.
 * \return true on success. false if accept() fails with an error other than
 *         EINTR; in that case none of the output arguments are modified.
 */
bool TCPSocket::Accept(TCPSocket *socket, std::string *ip, int *port) {
  int sock_client;
  SAI sa_client;
  socklen_t len;

  do {  // retry if EINTR failure appears
    // len is a value-result argument, so reset it before every attempt.
    len = sizeof(sa_client);
    sock_client = accept(socket_, reinterpret_cast<SA *>(&sa_client), &len);
  } while (sock_client == -1 && errno == EINTR);

  if (sock_client < 0) {
    // Snapshot errno once so both uses below report the same failure.
    const int accept_errno = errno;
    // ip/port are pure outputs that have not been written yet, so the
    // listening descriptor is reported instead of dereferencing them.
    LOG(ERROR) << "Failed accept connection on socket fd: " << socket_
               << ", error: " << strerror(accept_errno)
               << (accept_errno == EAGAIN ? " SO_RCVTIMEO timeout reached"
                                          : "");
    return false;
  }

  char tmp[INET_ADDRSTRLEN];
  const char *ip_client =
      inet_ntop(AF_INET, &sa_client.sin_addr, tmp, sizeof(tmp));
  CHECK(ip_client != nullptr);
  ip->assign(ip_client);
  *port = ntohs(sa_client.sin_port);

  // The TCPSocket constructor opens a descriptor of its own; release it
  // before taking over the accepted one, otherwise it would be leaked.
  socket->Close();
  socket->socket_ = sock_client;

  return true;
}

#ifdef _WIN32
132
bool TCPSocket::SetNonBlocking(bool flag) {
133
134
135
136
137
138
139
140
141
142
143
144
145
146
  int result;
  u_long argp = flag ? 1 : 0;

  // XXX Non-blocking Windows Sockets apparently has tons of issues:
  // http://www.sockets.com/winsock.htm#Overview_BlockingNonBlocking
  // Since SetBlocking() is not used at all, I'm leaving a default
  // implementation here.  But be warned that this is not fully tested.
  if ((result = ioctlsocket(socket_, FIONBIO, &argp)) != NO_ERROR) {
    LOG(ERROR) << "Failed to set socket status.";
    return false;
  }
  return true;
}
#else   // !_WIN32
147
/*!
 * \brief Switch the socket between blocking and non-blocking mode (POSIX).
 * \param flag true sets O_NONBLOCK on the descriptor, false clears it.
 * \return true if the status flags were read and written back successfully;
 *         false (after logging an error) if either fcntl() call fails.
 */
bool TCPSocket::SetNonBlocking(bool flag) {
  const int current_flags = fcntl(socket_, F_GETFL);
  if (current_flags < 0) {
    LOG(ERROR) << "Failed to get socket status.";
    return false;
  }

  // Touch only the O_NONBLOCK bit; every other status flag is preserved.
  const int wanted_flags =
      flag ? (current_flags | O_NONBLOCK) : (current_flags & ~O_NONBLOCK);

  const int status = fcntl(socket_, F_SETFL, wanted_flags);
  if (status < 0) {
    LOG(ERROR) << "Failed to set socket status.";
    return false;
  }

  return true;
}
#endif  // _WIN32

void TCPSocket::SetTimeout(int timeout) {
171
172
173
174
175
176
177
178
179
180
181
#ifdef _WIN32
  timeout = timeout * 1000;  // WIN API accepts millsec
  setsockopt(
      socket_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&timeout),
      sizeof(timeout));
#else   // !_WIN32
  struct timeval tv;
  tv.tv_sec = timeout;
  tv.tv_usec = 0;
  setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
#endif  // _WIN32
182
183
}

184
/*!
 * \brief Shut down part or all of the connection.
 * \param ways Value forwarded unchanged as the `how` argument of shutdown().
 * \return true if shutdown() returns 0, false otherwise.
 */
bool TCPSocket::ShutDown(int ways) {
  const int status = shutdown(socket_, ways);
  return status == 0;
}
185
186
187
188
189
190
191
192
193
194
195
196

/*!
 * \brief Close the socket descriptor if one is currently held.
 *
 * Does nothing when socket_ is negative. Otherwise the platform close call
 * must return 0 (enforced with CHECK_EQ), after which socket_ is reset to -1
 * so that a later call becomes a no-op.
 */
void TCPSocket::Close() {
  if (socket_ < 0) {
    return;
  }
#ifdef _WIN32
  CHECK_EQ(0, closesocket(socket_));
#else   // !_WIN32
  CHECK_EQ(0, close(socket_));
#endif  // _WIN32
  socket_ = -1;
}

197
int64_t TCPSocket::Send(const char *data, int64_t len_data) {
mszarma's avatar
mszarma committed
198
199
200
201
202
  int64_t number_send;

  do {  // retry if EINTR failure appears
    number_send = send(socket_, data, len_data, 0);
  } while (number_send == -1 && errno == EINTR);
203
204
205
  if (number_send == -1) {
    LOG(ERROR) << "send error: " << strerror(errno);
  }
mszarma's avatar
mszarma committed
206
207

  return number_send;
208
209
}

210
int64_t TCPSocket::Receive(char *buffer, int64_t size_buffer) {
mszarma's avatar
mszarma committed
211
212
213
214
215
  int64_t number_recv;

  do {  // retry if EINTR failure appears
    number_recv = recv(socket_, buffer, size_buffer, 0);
  } while (number_recv == -1 && errno == EINTR);
216
  if (number_recv == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
217
218
    LOG(ERROR) << "recv error: " << strerror(errno);
  }
mszarma's avatar
mszarma committed
219
220

  return number_recv;
221
222
}

223
/*! \brief Return the descriptor currently stored in this object. */
int TCPSocket::Socket() const {
  return socket_;
}
224
225
226

}  // namespace network
}  // namespace dgl