tcp_socket.cc 4.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
/*!
 *  Copyright (c) 2019 by Contributors
 * \file tcp_socket.cc
 * \brief TCP socket for DGL distributed training.
 */
#include "tcp_socket.h"

#include <dmlc/logging.h>

#ifndef _WIN32
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#endif  // !_WIN32

namespace dgl {
namespace network {

// Short aliases for the BSD socket address structs used throughout this file.
using SAI = sockaddr_in;  // IPv4 socket address
using SA = sockaddr;      // generic socket address expected by the socket API

/*!
 * \brief Open a fresh IPv4 TCP stream socket for this object.
 *
 * If socket() fails the error is reported through LOG(FATAL).
 */
TCPSocket::TCPSocket() {
  socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  if (socket_ >= 0) {
    return;
  }
  LOG(FATAL) << "Can't create new socket.";
}

/*!
 * \brief Release the underlying descriptor, if it is still open.
 */
TCPSocket::~TCPSocket() { this->Close(); }

bool TCPSocket::Connect(const char * ip, int port) {
  SAI sa_server;
  sa_server.sin_family      = AF_INET;
  sa_server.sin_port        = htons(port);

  if (0 < inet_pton(AF_INET, ip, &sa_server.sin_addr) &&
      0 <= connect(socket_, reinterpret_cast<SA*>(&sa_server),
                   sizeof(sa_server))) {
    return true;
  }

  LOG(ERROR) << "Failed connect to " << ip << ":" << port;
  return false;
}

bool TCPSocket::Bind(const char * ip, int port) {
  SAI sa_server;
  sa_server.sin_family      = AF_INET;
  sa_server.sin_port        = htons(port);

  if (0 < inet_pton(AF_INET, ip, &sa_server.sin_addr) &&
      0 <= bind(socket_, reinterpret_cast<SA*>(&sa_server),
                sizeof(sa_server))) {
    return true;
  }

  LOG(ERROR) << "Failed bind on " << ip << ":" << port;
  return false;
}

/*!
 * \brief Mark this (bound) socket as passive so it can accept connections.
 * \param max_connection backlog passed straight to listen()
 * \return true on success, false (after logging) if listen() fails.
 */
bool TCPSocket::Listen(int max_connection) {
  const int rc = listen(socket_, max_connection);
  if (rc < 0) {
    LOG(ERROR) << "Failed listen on socket fd: " << socket_;
    return false;
  }
  return true;
}

/*!
 * \brief Accept one pending connection on this listening socket.
 * \param socket output; on success it takes over the accepted connection's
 *        descriptor. Any descriptor it held before is closed first, because
 *        TCPSocket's constructor always opens one.
 * \param ip     output; the peer's IPv4 address in dotted-decimal form
 * \param port   output; the peer's TCP port in host byte order
 * \return true on success; false (after logging) if accept() fails, in which
 *         case none of the outputs are touched.
 */
bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
  SAI sa_client;
  socklen_t len = sizeof(sa_client);

  const int sock_client =
      accept(socket_, reinterpret_cast<SA*>(&sa_client), &len);
  if (sock_client < 0) {
    // *ip and *port are outputs and may be uninitialized on entry, so report
    // the listening descriptor instead of dereferencing them.
    LOG(ERROR) << "Failed accept connection on socket fd: " << socket_;
    return false;
  }

  // Hand the new descriptor to the target immediately. Closing its previous
  // descriptor avoids leaking it, and taking ownership now means the accepted
  // connection is still released by the target if the CHECK below fires.
  socket->Close();
  socket->socket_ = sock_client;

  char tmp[INET_ADDRSTRLEN];
  const char * ip_client = inet_ntop(AF_INET,
                                     &sa_client.sin_addr,
                                     tmp,
                                     sizeof(tmp));
  CHECK(ip_client != nullptr);
  ip->assign(ip_client);
  *port = ntohs(sa_client.sin_port);

  return true;
}

#ifdef _WIN32
bool TCPSocket::SetBlocking(bool flag) {
  int result;
  u_long argp = flag ? 1 : 0;

  // XXX Non-blocking Windows Sockets apparently has tons of issues:
  // http://www.sockets.com/winsock.htm#Overview_BlockingNonBlocking
  // Since SetBlocking() is not used at all, I'm leaving a default
  // implementation here.  But be warned that this is not fully tested.
  if ((result = ioctlsocket(socket_, FIONBIO, &argp)) != NO_ERROR) {
    LOG(ERROR) << "Failed to set socket status.";
    return false;
  }
  return true;
}
#else   // !_WIN32
/*!
 * \brief Toggle the O_NONBLOCK file status flag on the descriptor (POSIX).
 * \param flag NOTE: despite the function's name, true SETS O_NONBLOCK, which
 *        makes the socket non-blocking. false clears it, restoring blocking mode.
 * \return true on success, false (after logging) if either fcntl call fails.
 */
bool TCPSocket::SetBlocking(bool flag) {
  const int current = fcntl(socket_, F_GETFL);
  if (current < 0) {
    LOG(ERROR) << "Failed to get socket status.";
    return false;
  }

  const int updated = flag ? (current | O_NONBLOCK) : (current & ~O_NONBLOCK);
  if (fcntl(socket_, F_SETFL, updated) >= 0) {
    return true;
  }
  LOG(ERROR) << "Failed to set socket status.";
  return false;
}
#endif  // _WIN32

void TCPSocket::SetTimeout(int timeout) {
140
141
142
143
144
145
146
147
148
149
150
  #ifdef _WIN32
    timeout = timeout * 1000;  // WIN API accepts millsec
    setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO,
               reinterpret_cast<char*>(&timeout), sizeof(timeout));
  #else  // !_WIN32
    struct timeval tv;
    tv.tv_sec = timeout;
    tv.tv_usec = 0;
    setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO,
               &tv, sizeof(tv));
  #endif  // _WIN32
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
}

/*!
 * \brief Shut down one or both directions of the connection.
 * \param ways how-to-shut-down value forwarded unchanged to shutdown()
 * \return true iff shutdown() reported success (returned 0).
 */
bool TCPSocket::ShutDown(int ways) {
  const int rc = shutdown(socket_, ways);
  return rc == 0;
}

/*!
 * \brief Close the descriptor if it is open and mark the object as closed.
 *
 * Safe to call more than once: after the first call socket_ is -1 and later
 * calls do nothing. A failing close is reported through CHECK_EQ.
 */
void TCPSocket::Close() {
  if (socket_ < 0) {
    return;  // nothing to release
  }
#ifdef _WIN32
  CHECK_EQ(0, closesocket(socket_));
#else   // !_WIN32
  CHECK_EQ(0, close(socket_));
#endif  // _WIN32
  socket_ = -1;
}

/*!
 * \brief Send up to len_data bytes from data with a single send() call.
 * \return the raw send() result: the number of bytes written (possibly
 *         fewer than len_data), or a negative value on error.
 */
int64_t TCPSocket::Send(const char * data, int64_t len_data) {
  const int64_t sent = send(socket_, data, len_data, 0);
  return sent;
}

/*!
 * \brief Read up to size_buffer bytes into buffer with a single recv() call.
 * \return the raw recv() result: bytes read, 0 if the peer closed the
 *         connection, or a negative value on error or timeout.
 */
int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
  const int64_t received = recv(socket_, buffer, size_buffer, 0);
  return received;
}

/*!
 * \brief Raw descriptor held by this object (-1 once Close() has run).
 */
int TCPSocket::Socket() const { return socket_; }

}  // namespace network
}  // namespace dgl