socket_communicator.cc 8.94 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.cc
 * \brief SocketCommunicator for DGL distributed training.
 */
#include <dmlc/logging.h>

8
9
10
11
#include <string.h>
#include <stdlib.h>
#include <time.h>

12
13
14
15
16
17
18
19
20
21
22
23
24
#include "socket_communicator.h"
#include "../../c_api_common.h"

#ifdef _WIN32
#include <windows.h>
#else   // !_WIN32
#include <unistd.h>
#endif  // _WIN32

namespace dgl {
namespace network {


25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/////////////////////////////////////// SocketSender ///////////////////////////////////////////


/*!
 * \brief Register a receiver address under a receiver ID.
 * \param addr receiver address, e.g. 'socket://127.0.0.1:50051'
 * \param recv_id receiver ID; must not be negative
 *
 * The parsed IP and port are stored in receiver_addrs_[recv_id], and a
 * message queue of queue_size_ is created in msg_queue_[recv_id].
 * A malformed address or a negative ID is reported through LOG(FATAL).
 */
void SocketSender::AddReceiver(const char* addr, int recv_id) {
  CHECK_NOTNULL(addr);
  if (recv_id < 0) {
    LOG(FATAL) << "recv_id cannot be a negative number.";
  }
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format. The size is tested first so that substring[0]
  // is never read when the split produced no (or too few) pieces.
  if (substring.size() != 2 || substring[0] != "socket:") {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  IPAddr address;
  address.ip = ip_and_port[0];
  // NOTE(review): std::stoi throws if the port text is not a number;
  // that exception is not translated into a LOG(FATAL) here.
  address.port = std::stoi(ip_and_port[1]);
  receiver_addrs_[recv_id] = address;
  msg_queue_[recv_id] = std::make_shared<MessageQueue>(queue_size_);
}

56
57
bool SocketSender::Connect() {
  // Create N sockets for Receiver
58
  for (const auto& r : receiver_addrs_) {
59
    int ID = r.first;
60
61
    sockets_[ID] = std::make_shared<TCPSocket>();
    TCPSocket* client_socket = sockets_[ID].get();
62
63
    bool bo = false;
    int try_count = 0;
64
65
    const char* ip = r.second.ip.c_str();
    int port = r.second.port;
66
    while (bo == false && try_count < kMaxTryCount) {
67
      if (client_socket->Connect(ip, port)) {
68
69
70
71
72
        bo = true;
      } else {
        LOG(ERROR) << "Cannot connect to Receiver: " << ip << ":" << port
                   << ", try again ...";
        try_count++;
73
#ifdef _WIN32
74
        Sleep(1);
75
#else   // !_WIN32
76
        sleep(1);
77
#endif  // _WIN32
78
79
80
81
82
      }
    }
    if (bo == false) {
      return bo;
    }
83
84
85
86
87
    // Create a new thread for this socket connection
    threads_[ID] = std::make_shared<std::thread>(
      SendLoop,
      client_socket,
      msg_queue_[ID].get());
88
89
90
91
  }
  return true;
}

92
93
94
95
96
97
98
/*!
 * \brief Queue a message for delivery to the given receiver.
 * \param msg message to send; msg.data must not be null and msg.size > 0
 * \param recv_id receiver ID (>= 0) that has a queue in msg_queue_
 * \return the status code returned by the queue's Add()
 *
 * The message is only added to the receiver's queue here; the actual
 * socket write happens in SendLoop.
 */
STATUS SocketSender::Send(Message msg, int recv_id) {
  CHECK_NOTNULL(msg.data);
  CHECK_GT(msg.size, 0);
  CHECK_GE(recv_id, 0);
  // Look the queue up with find(): operator[] would silently insert a null
  // entry for an unknown ID and the call below would dereference it.
  auto it = msg_queue_.find(recv_id);
  if (it == msg_queue_.end()) {
    LOG(FATAL) << "Unknown recv_id: " << recv_id;
  }
  // Add data message to message queue
  return it->second->Add(msg);
}

void SocketSender::Finalize() {
102
103
104
105
106
107
108
109
110
  // Send a signal to tell the msg_queue to finish its job
  for (auto& mq : msg_queue_) {
    // wait until queue is empty
    while (mq.second->Empty() == false) {
#ifdef _WIN32
        // just loop
#else   // !_WIN32
        usleep(1000);
#endif  // _WIN32
111
    }
112
113
114
115
116
117
118
119
120
121
    int ID = mq.first;
    mq.second->SignalFinished(ID);
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
    thread.second->join();
  }
  // Clear all sockets
  for (auto& socket : sockets_) {
    socket.second->Close();
122
  }
123
124
}

125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
/*!
 * \brief Thread body: take messages off a queue and write them to a socket.
 * \param socket connected socket to write to; must not be null
 * \param queue message queue to read from; must not be null
 *
 * Each message goes out as an int64_t size followed by the payload bytes.
 * When the queue reports QUEUE_CLOSE a size of zero is sent as the
 * end-signal and the loop terminates.
 */
void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) {
  CHECK_NOTNULL(socket);
  CHECK_NOTNULL(queue);
  bool keep_running = true;
  do {
    Message msg;
    if (queue->Remove(&msg) == QUEUE_CLOSE) {
      // A zero size tells the receiver that no more data will follow
      msg.size = 0;
      keep_running = false;
    }
    // Size header: keep writing until all of its bytes are out
    char* const size_bytes = reinterpret_cast<char*>(&msg.size);
    int64_t offset = 0;
    while (static_cast<size_t>(offset) < sizeof(int64_t)) {
      const int64_t remaining = sizeof(int64_t) - offset;
      const int64_t written = socket->Send(size_bytes + offset, remaining);
      CHECK_NE(written, -1);
      offset += written;
    }
    // Payload: skipped entirely when the size is zero
    for (offset = 0; offset < msg.size;) {
      const int64_t remaining = msg.size - offset;
      const int64_t written = socket->Send(msg.data + offset, remaining);
      CHECK_NE(written, -1);
      offset += written;
    }
    // Release the message buffer if the message carries a deallocator
    if (msg.deallocator != nullptr) {
      msg.deallocator(&msg);
    }
  } while (keep_running);
}

162
/////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
163

164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
bool SocketReceiver::Wait(const char* addr, int num_sender) {
  CHECK_NOTNULL(addr);
  CHECK_GT(num_sender, 0);
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format
  if (substring[0] != "socket:" || substring.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  std::string ip = ip_and_port[0];
  int port = stoi(ip_and_port[1]);
  // Initialize message queue for each connection
186
  num_sender_ = num_sender;
187
188
189
190
191
192
  for (int i = 0; i < num_sender_; ++i) {
    msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
  }
  // Initialize socket and socket-thread
  server_socket_ = new TCPSocket();
  server_socket_->SetTimeout(kTimeOut * 60 * 1000);  // millsec
193
  // Bind socket
194
  if (server_socket_->Bind(ip.c_str(), port) == false) {
195
    LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
196
197
  }
  // Listen
198
  if (server_socket_->Listen(kMaxConnection) == false) {
199
    LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
200
201
202
203
  }
  // Accept all sender sockets
  std::string accept_ip;
  int accept_port;
204
205
206
207
  for (int i = 0; i < num_sender_; ++i) {
    sockets_[i] = std::make_shared<TCPSocket>();
    if (server_socket_->Accept(sockets_[i].get(), &accept_ip, &accept_port) == false) {
      LOG(WARNING) << "Error on accept socket.";
208
209
      return false;
    }
210
    // create new thread for each socket
211
212
213
214
    threads_[i] = std::make_shared<std::thread>(
      RecvLoop,
      sockets_[i].get(),
      msg_queue_[i].get());
215
216
217
218
219
  }

  return true;
}

220
221
/*!
 * \brief Receive the next available message from any sender.
 * \param msg output: the message that was removed from a queue
 * \param send_id output: ID of the queue that was polled last; on return
 *        this is the ID of the queue the status code came from
 * \return the first status code other than QUEUE_EMPTY returned by a queue
 *
 * The queues are polled round after round with a non-blocking Remove()
 * until one of them yields something other than QUEUE_EMPTY.
 */
STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
  while (true) {
    for (auto& entry : msg_queue_) {
      *send_id = entry.first;
      // Non-blocking remove: an empty queue must not stall the scan
      const STATUS status = entry.second->Remove(msg, false);
      if (status != QUEUE_EMPTY) {
        return status;
      }
      // Queue was empty; move on to the next one
    }
  }
}

/*!
 * \brief Receive the next message from one specific sender.
 * \param msg output: the message that was removed from the queue
 * \param send_id ID of the sender whose queue is read
 * \return the status code returned by the queue's Remove()
 */
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
  // Look the queue up with find(): operator[] would silently insert a null
  // entry for an unknown ID and the call below would dereference it.
  auto it = msg_queue_.find(send_id);
  if (it == msg_queue_.end()) {
    LOG(FATAL) << "Unknown send_id: " << send_id;
  }
  // Get message from specified message queue
  return it->second->Remove(msg);
}

/*!
 * \brief Drain all message queues, stop the receiving threads and close sockets.
 *
 * For each queue: wait until it is empty, then call SignalFinished() with
 * the queue's ID. Afterwards every thread in threads_ is joined and every
 * socket in sockets_ is closed.
 */
void SocketReceiver::Finalize() {
  // Send a signal to tell the message queue to finish its job
  for (auto& mq : msg_queue_) {
    // Wait until queue is empty, yielding the CPU between polls on every
    // platform instead of spinning.
    while (mq.second->Empty() == false) {
#ifdef _WIN32
      Sleep(1);      // Win32 Sleep() takes milliseconds
#else   // !_WIN32
      usleep(1000);  // usleep() takes microseconds
#endif  // _WIN32
    }
    int ID = mq.first;
    mq.second->SignalFinished(ID);
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
    thread.second->join();
  }
  // Clear all sockets
  // NOTE(review): server_socket_ is not closed here — confirm that the
  // listening socket is released elsewhere.
  for (auto& socket : sockets_) {
    socket.second->Close();
  }
}

/*!
 * \brief Thread body: read messages from a socket and add them to a queue.
 * \param socket connected socket to read from; must not be null
 * \param queue message queue to fill; must not be null
 *
 * Each message arrives as an int64_t size followed by that many payload
 * bytes. The loop ends when a size of zero (the sender's end-signal) is
 * read, or when the queue reports EmptyAndNoMoreAdd() at the top of an
 * iteration. The payload buffer is allocated with new[] and handed to the
 * queue together with DefaultMessageDeleter.
 */
void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) {
  CHECK_NOTNULL(socket);
  CHECK_NOTNULL(queue);
  for (;;) {
    // If main thread had finished its job
    if (queue->EmptyAndNoMoreAdd()) {
      return;  // exit loop thread
    }
    // First recv the size
    int64_t received_bytes = 0;
    int64_t data_size = 0;
    while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) {
      int64_t max_len = sizeof(int64_t) - received_bytes;
      int64_t tmp = socket->Receive(
        reinterpret_cast<char*>(&data_size)+received_bytes,
        max_len);
      // A result of 0 would never advance received_bytes and the loop would
      // spin forever, so it is rejected together with the -1 error code.
      // NOTE(review): assumes Receive() follows recv() semantics, where 0
      // means the peer closed the connection — confirm against TCPSocket.
      CHECK_GT(tmp, 0) << "Socket receive failed or connection closed by peer";
      received_bytes += tmp;
    }
    if (data_size < 0) {
      LOG(FATAL) << "Recv data error (data_size: " << data_size << ")";
    } else if (data_size == 0) {
      // This is an end-signal sent by client
      return;
    } else {
      char* buffer = nullptr;
      try {
        buffer = new char[data_size];
      } catch(const std::bad_alloc&) {
        LOG(FATAL) << "Cannot allocate enough memory for message, "
                   << "(message size: " << data_size << ")";
      }
      // Then recv the data
      received_bytes = 0;
      while (received_bytes < data_size) {
        int64_t max_len = data_size - received_bytes;
        int64_t tmp = socket->Receive(buffer+received_bytes, max_len);
        // Same reasoning as for the size header: 0 cannot make progress.
        CHECK_GT(tmp, 0) << "Socket receive failed or connection closed by peer";
        received_bytes += tmp;
      }
      Message msg;
      msg.data = buffer;
      msg.size = data_size;
      msg.deallocator = DefaultMessageDeleter;
      queue->Add(msg);
    }
  }
}

}  // namespace network
}  // namespace dgl