tcp_socket.cc 5.28 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/*!
 *  Copyright (c) 2019 by Contributors
 * \file tcp_socket.cc
 * \brief TCP socket for DGL distributed training.
 */
#include "tcp_socket.h"

#include <dmlc/logging.h>

#ifndef _WIN32
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#endif  // !_WIN32
18
19
#include <string.h>
#include <errno.h>
20
21
22
23
24
25
26
27
28
29
30

namespace dgl {
namespace network {

// Short aliases for the BSD socket address structs used throughout this file.
using SAI = struct sockaddr_in;  // IPv4 socket address
using SA = struct sockaddr;      // generic address type expected by the socket API

/*!
 * \brief Open a new IPv4 TCP socket owned by this object.
 *
 * A failure of socket() is reported through LOG(FATAL) together with errno.
 */
TCPSocket::TCPSocket() {
  socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  const bool created = socket_ >= 0;
  if (!created) {
    LOG(FATAL) << "Can't create new socket. Errno=" << errno;
  }
}

/*!
 * \brief Destructor; releases the underlying descriptor via Close().
 */
TCPSocket::~TCPSocket() {
  this->Close();
}

/*!
 * \brief Connect this socket to a remote IPv4 endpoint.
 * \param ip peer address in dotted-decimal text form (parsed with inet_pton)
 * \param port peer port in host byte order (converted with htons)
 * \return true on success; false if `ip` is not a valid IPv4 address or
 *         connect() fails with an error other than EINTR. Nothing is logged
 *         on failure, so callers may retry quietly.
 */
bool TCPSocket::Connect(const char * ip, int port) {
  // Zero the whole struct so sin_zero and any padding are not handed to the
  // kernel uninitialized.
  SAI sa_server;
  memset(&sa_server, 0, sizeof(sa_server));
  sa_server.sin_family = AF_INET;
  sa_server.sin_port = htons(port);

  // The address does not change between retries, so parse it only once.
  if (inet_pton(AF_INET, ip, &sa_server.sin_addr) <= 0) {
    return false;
  }

  int retval;
  do {  // retry if EINTR failure appears
    retval = connect(socket_, reinterpret_cast<SA*>(&sa_server),
                     sizeof(sa_server));
  } while (retval == -1 && errno == EINTR);

  return retval >= 0;
}

bool TCPSocket::Bind(const char * ip, int port) {
  SAI sa_server;
  sa_server.sin_family      = AF_INET;
  sa_server.sin_port        = htons(port);
mszarma's avatar
mszarma committed
60
61
62
63
64
65
66
67
68
69
  int retval = 0;
  do {  // retry if EINTR failure appears
    if (0 < inet_pton(AF_INET, ip, &sa_server.sin_addr) &&
        0 <= (retval = bind(socket_, reinterpret_cast<SA*>(&sa_server),
                  sizeof(sa_server)))) {
      return true;
    }
  } while (retval == -1 && errno == EINTR);

  LOG(ERROR) << "Failed bind on " << ip << ":" << port << " ,errno=" << errno;
70
71
72
73
  return false;
}

/*!
 * \brief Mark this socket as passive so it can accept connections.
 * \param max_connection backlog argument forwarded to listen()
 * \return true on success; false (after logging errno) otherwise.
 *         listen() is retried when it is interrupted (EINTR).
 */
bool TCPSocket::Listen(int max_connection) {
  for (;;) {
    const int rc = listen(socket_, max_connection);
    if (rc >= 0) {
      return true;
    }
    if (rc != -1 || errno != EINTR) {
      break;  // a real failure, not an interruption
    }
  }

  LOG(ERROR) << "Failed listen on socket fd: " << socket_ << " ,errno=" << errno;
  return false;
}

/*!
 * \brief Accept one pending connection on this listening socket.
 * \param socket receives the descriptor of the accepted connection; any
 *        descriptor it already owns is closed first
 * \param ip output: peer IPv4 address as text
 * \param port output: peer port in host byte order
 * \return true on success; false (after logging) if accept() fails with an
 *         error other than EINTR. EAGAIN is reported as an SO_RCVTIMEO timeout.
 */
bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
  SAI sa_client;
  socklen_t len;
  int sock_client;

  do {  // retry if EINTR failure appears
    // accept() uses len as an in/out argument, so reset it before each attempt.
    len = sizeof(sa_client);
    sock_client = accept(socket_, reinterpret_cast<SA*>(&sa_client), &len);
  } while (sock_client == -1 && errno == EINTR);

  if (sock_client < 0) {
    // Capture errno before logging can clobber it. *ip and *port are outputs
    // that have not been written yet, so identify the listening socket by its
    // descriptor instead of reading them.
    const int err = errno;
    LOG(ERROR) << "Failed accept connection on socket fd: " << socket_
               << " ,errno=" << err
               << (err == EAGAIN ? " SO_RCVTIMEO timeout reached" : "");
    return false;
  }

  char tmp[INET_ADDRSTRLEN];
  const char * ip_client = inet_ntop(AF_INET,
                                     &sa_client.sin_addr,
                                     tmp,
                                     sizeof(tmp));
  CHECK(ip_client != nullptr);
  ip->assign(ip_client);
  *port = ntohs(sa_client.sin_port);

  // The TCPSocket constructor always opens a descriptor. Release it before
  // overwriting it with the accepted connection, otherwise it leaks.
  socket->Close();
  socket->socket_ = sock_client;

  return true;
}

#ifdef _WIN32
/*!
 * \brief Toggle the socket's non-blocking mode (Winsock variant).
 * \param flag true sets FIONBIO to 1 (non-blocking), false sets it to 0
 * \return true on success, false (after logging) if ioctlsocket() fails.
 *
 * NOTE(review): despite the name, passing true enables NON-blocking mode;
 * the POSIX variant below behaves the same way. Confirm against the header
 * and callers before relying on it.
 */
bool TCPSocket::SetBlocking(bool flag) {
  u_long mode = flag ? 1 : 0;

  // XXX Non-blocking Windows Sockets apparently has tons of issues:
  // http://www.sockets.com/winsock.htm#Overview_BlockingNonBlocking
  // Since SetBlocking() is not used at all, I'm leaving a default
  // implementation here.  But be warned that this is not fully tested.
  if (ioctlsocket(socket_, FIONBIO, &mode) == NO_ERROR) {
    return true;
  }
  LOG(ERROR) << "Failed to set socket status.";
  return false;
}
#else   // !_WIN32
/*!
 * \brief Toggle O_NONBLOCK on the socket descriptor (POSIX variant).
 * \param flag true adds O_NONBLOCK (non-blocking), false clears it
 * \return true on success, false (after logging) if fcntl() fails.
 *
 * NOTE(review): despite the name, passing true enables NON-blocking mode.
 * Confirm this is the intended contract.
 */
bool TCPSocket::SetBlocking(bool flag) {
  const int current = fcntl(socket_, F_GETFL);
  if (current < 0) {
    LOG(ERROR) << "Failed to get socket status.";
    return false;
  }

  const int updated = flag ? (current | O_NONBLOCK) : (current & ~O_NONBLOCK);
  if (fcntl(socket_, F_SETFL, updated) < 0) {
    LOG(ERROR) << "Failed to set socket status.";
    return false;
  }
  return true;
}
#endif  // _WIN32

void TCPSocket::SetTimeout(int timeout) {
153
154
155
156
157
158
159
160
161
162
163
  #ifdef _WIN32
    timeout = timeout * 1000;  // WIN API accepts millsec
    setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO,
               reinterpret_cast<char*>(&timeout), sizeof(timeout));
  #else  // !_WIN32
    struct timeval tv;
    tv.tv_sec = timeout;
    tv.tv_usec = 0;
    setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO,
               &tv, sizeof(tv));
  #endif  // _WIN32
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
}

/*!
 * \brief Shut down part or all of the connection.
 * \param ways `how` argument forwarded unchanged to shutdown()
 * \return true iff shutdown() returned 0.
 */
bool TCPSocket::ShutDown(int ways) {
  const int rc = shutdown(socket_, ways);
  return rc == 0;
}

/*!
 * \brief Close the descriptor if one is held, then mark it as -1.
 *
 * Calling this again after the first close is a no-op. A non-zero result
 * from close()/closesocket() trips CHECK_EQ.
 */
void TCPSocket::Close() {
  if (socket_ < 0) {
    return;  // nothing to release
  }
#ifdef _WIN32
  CHECK_EQ(0, closesocket(socket_));
#else   // !_WIN32
  CHECK_EQ(0, close(socket_));
#endif  // _WIN32
  socket_ = -1;
}

int64_t TCPSocket::Send(const char * data, int64_t len_data) {
mszarma's avatar
mszarma committed
182
183
184
185
186
  int64_t number_send;

  do {  // retry if EINTR failure appears
    number_send = send(socket_, data, len_data, 0);
  } while (number_send == -1 && errno == EINTR);
187
188
189
  if (number_send == -1) {
    LOG(ERROR) << "send error: " << strerror(errno);
  }
mszarma's avatar
mszarma committed
190
191

  return number_send;
192
193
194
}

int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
mszarma's avatar
mszarma committed
195
196
197
198
199
  int64_t number_recv;

  do {  // retry if EINTR failure appears
    number_recv = recv(socket_, buffer, size_buffer, 0);
  } while (number_recv == -1 && errno == EINTR);
200
201
202
  if (number_recv == -1) {
    LOG(ERROR) << "recv error: " << strerror(errno);
  }
mszarma's avatar
mszarma committed
203
204

  return number_recv;
205
206
207
208
209
210
211
212
}

int TCPSocket::Socket() const {
  return socket_;
}

}  // namespace network
}  // namespace dgl