socket_communicator.cc 12.1 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.cc
 * \brief SocketCommunicator for DGL distributed training.
 */
#include <dmlc/logging.h>

8
9
10
#include <string.h>
#include <stdlib.h>
#include <time.h>
11
#include <memory>
12

13
14
#include "socket_communicator.h"
#include "../../c_api_common.h"
15
#include "socket_pool.h"
16
17
18
19
20
21
22
23
24
25
26

#ifdef _WIN32
#include <windows.h>
#else   // !_WIN32
#include <unistd.h>
#endif  // _WIN32

namespace dgl {
namespace network {


27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/////////////////////////////////////// SocketSender ///////////////////////////////////////////


/*!
 * \brief Record the address of a receiver so that Connect() can dial it.
 * \param addr Receiver address, formatted as 'socket://<ip>:<port>'.
 * \param recv_id Non-negative receiver ID; used as the key in receiver_addrs_.
 *
 * A negative recv_id or a malformed address is reported via LOG(FATAL).
 */
void SocketSender::AddReceiver(const char* addr, int recv_id) {
  CHECK_NOTNULL(addr);
  if (recv_id < 0) {
    LOG(FATAL) << "recv_id cannot be a negative number.";
  }
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format. The size test must come first: indexing
  // substring[0] on an empty split result would be undefined behavior.
  if (substring.size() != 2 || substring[0] != "socket:") {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  IPAddr address;
  address.ip = ip_and_port[0];
  address.port = std::stoi(ip_and_port[1]);
  receiver_addrs_[recv_id] = address;
}

57
58
bool SocketSender::Connect() {
  // Create N sockets for Receiver
59
60
61
62
63
  int receiver_count = static_cast<int>(receiver_addrs_.size());
  if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
    max_thread_count_ = receiver_count;
  }
  sockets_.resize(max_thread_count_);
64
  for (const auto& r : receiver_addrs_) {
65
66
67
68
    int receiver_id = r.first;
    int thread_id = receiver_id % max_thread_count_;
    sockets_[thread_id][receiver_id] = std::make_shared<TCPSocket>();
    TCPSocket* client_socket = sockets_[thread_id][receiver_id].get();
69
70
    bool bo = false;
    int try_count = 0;
71
72
    const char* ip = r.second.ip.c_str();
    int port = r.second.port;
73
    while (bo == false && try_count < kMaxTryCount) {
74
      if (client_socket->Connect(ip, port)) {
75
76
        bo = true;
      } else {
77
78
        if (try_count % 200 == 0 && try_count != 0) {
          // every 1000 seconds show this message
79
80
          LOG(INFO) << "Try to connect to: " << ip << ":" << port;
        }
81
        try_count++;
82
#ifdef _WIN32
83
        Sleep(5);
84
#else   // !_WIN32
85
        sleep(5);
86
#endif  // _WIN32
87
88
89
90
91
      }
    }
    if (bo == false) {
      return bo;
    }
92
93
94
95
  }

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
    msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));
96
    // Create a new thread for this socket connection
97
    threads_.push_back(std::make_shared<std::thread>(
98
      SendLoop,
99
100
      sockets_[thread_id],
      msg_queue_[thread_id]));
101
  }
102

103
104
105
  return true;
}

106
107
108
109
/*!
 * \brief Queue a message for asynchronous delivery to one receiver.
 * \param msg Message to send; its data must be non-null and its size positive.
 * \param recv_id Non-negative ID of the target receiver.
 * \return The status code reported by the message queue's Add().
 */
STATUS SocketSender::Send(Message msg, int recv_id) {
  CHECK_NOTNULL(msg.data);
  CHECK_GT(msg.size, 0);
  CHECK_GE(recv_id, 0);
  // Tag the message so the sending thread can pick the matching socket.
  msg.receiver_id = recv_id;
  // Receivers are spread over the per-thread queues by ID.
  const int queue_index = recv_id % max_thread_count_;
  return msg_queue_[queue_index]->Add(msg);
}

void SocketSender::Finalize() {
117
  // Send a signal to tell the msg_queue to finish its job
118
  for (int i = 0; i < max_thread_count_; ++i) {
119
    // wait until queue is empty
120
121
    auto& mq = msg_queue_[i];
    while (mq->Empty() == false) {
122
123
124
125
126
#ifdef _WIN32
        // just loop
#else   // !_WIN32
        usleep(1000);
#endif  // _WIN32
127
    }
128
129
130
    // All queues have only one producer, which is main thread, so
    // the producerID argument here should be zero.
    mq->SignalFinished(0);
131
132
133
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
134
    thread->join();
135
136
  }
  // Clear all sockets
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
  for (auto& group_sockets_ : sockets_) {
    for (auto &socket : group_sockets_) {
      socket.second->Close();
    }
  }
}

/*!
 * \brief Write one message to a socket and release its buffer.
 * \param msg Message to transmit (taken by value).
 * \param socket Connected socket to write to.
 *
 * The int64_t message size is written first, followed by msg.size bytes of
 * payload. A size of zero therefore transmits only the size field.
 */
void SendCore(Message msg, TCPSocket* socket) {
  // First send the size, resuming after any partial write
  char* const size_bytes = reinterpret_cast<char*>(&msg.size);
  int64_t offset = 0;
  while (static_cast<size_t>(offset) < sizeof(int64_t)) {
    const int64_t remaining = sizeof(int64_t) - offset;
    const int64_t written = socket->Send(size_bytes + offset, remaining);
    CHECK_NE(written, -1);
    offset += written;
  }
  // Then send the data
  for (offset = 0; offset < msg.size; ) {
    const int64_t remaining = msg.size - offset;
    const int64_t written = socket->Send(msg.data + offset, remaining);
    CHECK_NE(written, -1);
    offset += written;
  }
  // Release the message through its deallocator, if one was provided
  if (msg.deallocator != nullptr) {
    msg.deallocator(&msg);
  }
}

170
171
172
173
void SocketSender::SendLoop(
  std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
  std::shared_ptr<MessageQueue> queue) {
  for (;;) {
174
175
176
177
    Message msg;
    STATUS code = queue->Remove(&msg);
    if (code == QUEUE_CLOSE) {
      msg.size = 0;  // send an end-signal to receiver
178
179
180
181
      for (auto& socket : sockets) {
        SendCore(msg, socket.second.get());
      }
      break;
182
    }
183
    SendCore(msg, sockets[msg.receiver_id].get());
184
  }
185
186
}

187
/////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
188

189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
bool SocketReceiver::Wait(const char* addr, int num_sender) {
  CHECK_NOTNULL(addr);
  CHECK_GT(num_sender, 0);
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format
  if (substring[0] != "socket:" || substring.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  std::string ip = ip_and_port[0];
  int port = stoi(ip_and_port[1]);
  // Initialize message queue for each connection
211
  num_sender_ = num_sender;
212
213
214
#ifdef USE_EPOLL
  if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {
      max_thread_count_ = num_sender_;
215
  }
216
217
218
#else
  max_thread_count_ = num_sender_;
#endif
219
220
  // Initialize socket and socket-thread
  server_socket_ = new TCPSocket();
221
  // Bind socket
222
  if (server_socket_->Bind(ip.c_str(), port) == false) {
223
    LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
224
  }
225

226
  // Listen
227
  if (server_socket_->Listen(kMaxConnection) == false) {
228
    LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
229
230
231
232
  }
  // Accept all sender sockets
  std::string accept_ip;
  int accept_port;
233
  sockets_.resize(max_thread_count_);
234
  for (int i = 0; i < num_sender_; ++i) {
235
236
237
238
239
    int thread_id = i % max_thread_count_;
    auto socket = std::make_shared<TCPSocket>();
    sockets_[thread_id][i] = socket;
    msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
    if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) == false) {
240
      LOG(WARNING) << "Error on accept socket.";
241
242
      return false;
    }
243
244
245
246
  }
  mq_iter_ = msg_queue_.begin();

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
247
    // create new thread for each socket
248
    threads_.push_back(std::make_shared<std::thread>(
249
      RecvLoop,
250
251
252
      sockets_[thread_id],
      msg_queue_,
      &queue_sem_));
253
254
255
256
257
  }

  return true;
}

258
/*!
 * \brief Fetch the next available message from any sender.
 * \param msg Output: the received message.
 * \param send_id Output: ID of the sender whose queue produced the message.
 * \return The status code of the queue Remove() call that produced a result.
 *
 * Queues are scanned round-robin, resuming after the queue that served the
 * previous call.
 */
STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
  // queue_sem_ is a semaphore indicating how many elements in multiple
  // message queues.
  // When calling queue_sem_.Wait(), this Recv will be suspended until
  // queue_sem_ > 0, decrease queue_sem_ by 1, then start to fetch a message.
  queue_sem_.Wait();
  for (;;) {
    while (mq_iter_ != msg_queue_.end()) {
      const STATUS code = mq_iter_->second->Remove(msg, false);
      if (code != QUEUE_EMPTY) {
        *send_id = mq_iter_->first;
        // Start the next scan from the following queue
        ++mq_iter_;
        return code;
      }
      ++mq_iter_;  // jump to the next queue
    }
    // Reached the end without a result; wrap around and scan again
    mq_iter_ = msg_queue_.begin();
  }
}

/*!
 * \brief Fetch the next message sent by one specific sender.
 * \param msg Output: the received message.
 * \param send_id ID of the sender to receive from; must have a message queue.
 * \return The status code of the queue's Remove() call.
 *
 * An unknown send_id is reported via LOG(FATAL).
 */
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
  // Look the queue up with find() instead of operator[]: the latter would
  // insert a null entry for an unknown sender, which would be dereferenced
  // below and could also invalidate mq_iter_ through a rehash.
  auto queue = msg_queue_.find(send_id);
  if (queue == msg_queue_.end()) {
    LOG(FATAL) << "No message queue exists for sender ID: " << send_id;
  }
  // Get message from specified message queue
  queue_sem_.Wait();
  return queue->second->Remove(msg);
}

void SocketReceiver::Finalize() {
  // Send a signal to tell the message queue to finish its job
  for (auto& mq : msg_queue_) {
    // wait until queue is empty
    while (mq.second->Empty() == false) {
#ifdef _WIN32
        // just loop
#else   // !_WIN32
        usleep(1000);
#endif  // _WIN32
    }
297
    mq.second->SignalFinished(mq.first);
298
299
300
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
301
    thread->join();
302
303
  }
  // Clear all sockets
304
305
306
307
  for (auto& group_sockets : sockets_) {
    for (auto& socket : group_sockets) {
      socket.second->Close();
    }
308
  }
Chao Ma's avatar
Chao Ma committed
309
310
  server_socket_->Close();
  delete server_socket_;
311
312
}

313
314
315
316
317
318
319
320
321
322
323
324
325
326
/*!
 * \brief Read the int64_t size field that precedes every message.
 * \param socket Socket to read from.
 * \return The message size, or -1 if the very first read attempt yields -1
 *         (nothing of the size field has been consumed in that case).
 *
 * Once part of the field has been read, a -1 from Receive() is retried until
 * the whole field is available.
 */
int64_t RecvDataSize(TCPSocket* socket) {
  int64_t data_size = 0;
  char* const dest = reinterpret_cast<char*>(&data_size);
  int64_t filled = 0;
  while (static_cast<size_t>(filled) < sizeof(int64_t)) {
    const int64_t wanted = sizeof(int64_t) - filled;
    const int64_t got = socket->Receive(dest + filled, wanted);
    if (got != -1) {
      filled += got;
      continue;
    }
    if (filled <= 0) {
      // Nothing read so far: let the caller try again later
      return -1;
    }
    // We want to finish reading full data_size, so keep retrying
  }
  return data_size;
}

/*!
 * \brief Read as much of a message payload as is currently available.
 * \param socket Socket to read from.
 * \param buffer Destination buffer for the payload.
 * \param data_size Total payload size in bytes.
 * \param received_bytes In/out: bytes already stored in buffer; advanced by
 *        the number of bytes read during this call.
 *
 * Returns when the payload is complete or when Receive() yields -1; in the
 * latter case *received_bytes records the progress so a later call resumes.
 */
void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size,
  int64_t *received_bytes) {
  int64_t& filled = *received_bytes;
  while (filled < data_size) {
    const int64_t got = socket->Receive(buffer + filled, data_size - filled);
    if (got == -1) {
      // Socket not ready, no more data to read
      return;
    }
    filled += got;
  }
}

void SocketReceiver::RecvLoop(
  std::unordered_map<int /* Sender (virtual) ID */,
    std::shared_ptr<TCPSocket>> sockets,
  std::unordered_map<int /* Sender (virtual) ID */,
    std::shared_ptr<MessageQueue>> queues,
  runtime::Semaphore *queue_sem) {
  std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;
  SocketPool socket_pool;
  for (auto& socket : sockets) {
    auto &sender_id = socket.first;
    socket_pool.AddSocket(socket.second, sender_id);
    recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());
  }

  // Main loop to receive messages
  for (;;) {
    int sender_id;
    // Get active socket using epoll
    std::shared_ptr<TCPSocket> socket = socket_pool.GetActiveSocket(&sender_id);
    if (queues[sender_id]->EmptyAndNoMoreAdd()) {
      // This sender has already stopped
      if (socket_pool.RemoveSocket(socket) == 0) {
        return;
369
      }
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
      continue;
    }

    // Nonblocking socket might be interrupted at any point. So we need to
    // store the partially received data
    std::unique_ptr<RecvContext> &ctx = recv_contexts[sender_id];
    int64_t &data_size = ctx->data_size;
    int64_t &received_bytes = ctx->received_bytes;
    char*& buffer = ctx->buffer;

    if (data_size == -1) {
      // This is a new message, so receive the data size first
      data_size = RecvDataSize(socket.get());
      if (data_size > 0) {
        try {
          buffer = new char[data_size];
        } catch(const std::bad_alloc&) {
          LOG(FATAL) << "Cannot allocate enough memory for message, "
                     << "(message size: " << data_size << ")";
        }
        received_bytes = 0;
      } else if (data_size == 0) {
        // Received stop signal
        if (socket_pool.RemoveSocket(socket) == 0) {
          return;
        }
396
      }
397
398
399
400
401
    }

    RecvData(socket.get(), buffer, data_size, &received_bytes);
    if (received_bytes >= data_size) {
      // Full data received, create Message and push to queue
402
403
404
405
      Message msg;
      msg.data = buffer;
      msg.size = data_size;
      msg.deallocator = DefaultMessageDeleter;
406
407
408
409
410
411
412
      queues[sender_id]->Add(msg);

      // Reset recv context
      data_size = -1;

      // Signal queue semaphore
      queue_sem->Post();
413
414
415
416
417
418
    }
  }
}

}  // namespace network
}  // namespace dgl