socket_communicator.cc 14.8 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.cc
 * \brief SocketCommunicator for DGL distributed training.
 */
#include <dmlc/logging.h>

8
9
10
#include <string.h>
#include <stdlib.h>
#include <time.h>
11
#include <memory>
12

13
14
#include "socket_communicator.h"
#include "../../c_api_common.h"
15
#include "socket_pool.h"
16
17
18
19
20
21
22
23
24
25
26

#ifdef _WIN32
#include <windows.h>
#else   // !_WIN32
#include <unistd.h>
#endif  // _WIN32

namespace dgl {
namespace network {


27
28
29
/////////////////////////////////////// SocketSender ///////////////////////////////////////////


30
/*!
 * \brief Parse a receiver address and record it under the given receiver ID.
 *
 * This only stores the parsed IP and port in receiver_addrs_; no socket is
 * opened here.
 *
 * \param addr receiver address, e.g. 'tcp://127.0.0.1:50051'
 * \param recv_id receiver ID, must not be negative
 * \return true once the address has been recorded; a negative ID or a
 *         malformed address is reported through LOG(FATAL)
 */
bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
  if (recv_id < 0) {
    LOG(FATAL) << "recv_id cannot be a negative number.";
  }
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format. The size test has to come first: if the split
  // produced no pieces, reading substring[0] would be out of range.
  if (substring.size() != 2 || substring[0] != "tcp:") {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'tcp://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'tcp://127.0.0.1:50051'. ";
  }
  IPAddr address;
  address.ip = ip_and_port[0];
  address.port = std::stoi(ip_and_port[1]);
  receiver_addrs_[recv_id] = address;

  return true;
}

58
bool SocketSender::ConnectReceiverFinalize(const int max_try_times) {
59
  // Create N sockets for Receiver
60
61
62
63
64
  int receiver_count = static_cast<int>(receiver_addrs_.size());
  if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
    max_thread_count_ = receiver_count;
  }
  sockets_.resize(max_thread_count_);
65
  for (const auto& r : receiver_addrs_) {
66
67
68
69
    int receiver_id = r.first;
    int thread_id = receiver_id % max_thread_count_;
    sockets_[thread_id][receiver_id] = std::make_shared<TCPSocket>();
    TCPSocket* client_socket = sockets_[thread_id][receiver_id].get();
70
71
    bool bo = false;
    int try_count = 0;
72
73
    const char* ip = r.second.ip.c_str();
    int port = r.second.port;
74
    while (bo == false && try_count < max_try_times) {
75
      if (client_socket->Connect(ip, port)) {
76
77
        bo = true;
      } else {
78
        if (try_count % 200 == 0 && try_count != 0) {
79
80
          // every 600 seconds show this message
          LOG(INFO) << "Trying to connect receiver: " << ip << ":" << port;
81
        }
82
        try_count++;
83
        std::this_thread::sleep_for(std::chrono::seconds(3));
84
85
86
87
88
      }
    }
    if (bo == false) {
      return bo;
    }
89
90
91
92
  }

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
    msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));
93
    // Create a new thread for this socket connection
94
    threads_.push_back(std::make_shared<std::thread>(
95
      SendLoop,
96
97
      sockets_[thread_id],
      msg_queue_[thread_id]));
98
  }
99

100
101
102
  return true;
}

103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/*!
 * \brief Serialize an RPC message and enqueue it for the given receiver.
 *
 * The message is written into a metadata blob, followed by a trailing int32
 * holding the size of the stream's buffer list. The blob is enqueued first,
 * then every entry of the buffer list is enqueued as its own message.
 *
 * \param msg RPC message to send
 * \param recv_id receiver ID
 */
void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
  auto zerocopy_blob = std::make_shared<std::string>();
  StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true);
  zc_write_strm.Write(msg);

  // Append the number of buffers that follow the metadata message
  int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
  zerocopy_blob->append(
    reinterpret_cast<char*>(&nonempty_ndarray_count), sizeof(int32_t));

  Message rpc_meta_msg;
  rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data());
  rpc_meta_msg.size = zerocopy_blob->size();
  // The capture keeps the blob alive for as long as the message exists
  rpc_meta_msg.deallocator = [zerocopy_blob](Message*) {};
  CHECK_EQ(Send(rpc_meta_msg, recv_id), ADD_SUCCESS);

  // send real ndarray data
  for (const auto& entry : zc_write_strm.buffer_list()) {
    Message ndarray_data_msg;
    ndarray_data_msg.data = reinterpret_cast<char*>(entry.data);
    if (entry.size == 0) {
      LOG(FATAL) << "Cannot send a empty NDArray.";
    }
    ndarray_data_msg.size = entry.size;
    // The capture keeps the tensor alive for as long as the message exists
    NDArray tensor = entry.tensor;
    ndarray_data_msg.deallocator = [tensor](Message*) {};
    CHECK_EQ(Send(ndarray_data_msg, recv_id), ADD_SUCCESS);
  }
}

131
132
133
134
/*!
 * \brief Put a data message on the queue of the thread serving a receiver.
 *
 * The queue is selected by (recv_id % max_thread_count_).
 *
 * \param msg message to send; its data must be non-null and its size positive
 * \param recv_id receiver ID, must not be negative
 * \return status code returned by the message queue's Add()
 */
STATUS SocketSender::Send(Message msg, int recv_id) {
  CHECK_NOTNULL(msg.data);
  CHECK_GT(msg.size, 0);
  CHECK_GE(recv_id, 0);
  // The queues are created by ConnectReceiverFinalize(). Without them the
  // modulus below would divide by zero and the index would be out of range.
  CHECK_GT(max_thread_count_, 0)
    << "Send() called before ConnectReceiverFinalize().";
  const size_t queue_id = static_cast<size_t>(recv_id % max_thread_count_);
  CHECK_LT(queue_id, msg_queue_.size())
    << "No message queue for receiver " << recv_id;
  msg.receiver_id = recv_id;
  // Add data message to message queue
  return msg_queue_[queue_id]->Add(msg);
}

void SocketSender::Finalize() {
142
  // Send a signal to tell the msg_queue to finish its job
143
  for (int i = 0; i < max_thread_count_; ++i) {
144
    // wait until queue is empty
145
146
    auto& mq = msg_queue_[i];
    while (mq->Empty() == false) {
147
      std::this_thread::sleep_for(std::chrono::seconds(1));
148
    }
149
150
151
    // All queues have only one producer, which is main thread, so
    // the producerID argument here should be zero.
    mq->SignalFinished(0);
152
153
154
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
155
    thread->join();
156
157
  }
  // Clear all sockets
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
  for (auto& group_sockets_ : sockets_) {
    for (auto &socket : group_sockets_) {
      socket.second->Close();
    }
  }
}

/*!
 * \brief Write one message to a socket: first msg.size as int64-sized raw
 *        bytes, then msg.size bytes of payload.
 *
 * A size of zero sends only the size field. After sending, the message's
 * deallocator is invoked if one is set.
 *
 * \param msg message to send
 * \param socket socket to write to
 */
void SendCore(Message msg, TCPSocket* socket) {
  // First send the size field, retrying until all of its bytes are written
  char* size_bytes = reinterpret_cast<char*>(&msg.size);
  int64_t done = 0;
  while (static_cast<size_t>(done) < sizeof(int64_t)) {
    const int64_t remaining = sizeof(int64_t) - done;
    const int64_t written = socket->Send(size_bytes + done, remaining);
    CHECK_NE(written, -1);
    done += written;
  }
  // Then send the payload
  done = 0;
  while (done < msg.size) {
    const int64_t remaining = msg.size - done;
    const int64_t written = socket->Send(msg.data + done, remaining);
    CHECK_NE(written, -1);
    done += written;
  }
  // Release the message
  if (msg.deallocator != nullptr) {
    msg.deallocator(&msg);
  }
}

191
192
193
194
void SocketSender::SendLoop(
  std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
  std::shared_ptr<MessageQueue> queue) {
  for (;;) {
195
196
197
198
    Message msg;
    STATUS code = queue->Remove(&msg);
    if (code == QUEUE_CLOSE) {
      msg.size = 0;  // send an end-signal to receiver
199
200
201
202
      for (auto& socket : sockets) {
        SendCore(msg, socket.second.get());
      }
      break;
203
    }
204
    SendCore(msg, sockets[msg.receiver_id].get());
205
  }
206
207
}

208
/////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
209
bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
210
  CHECK_GT(num_sender, 0);
211
  CHECK_EQ(blocking, true);
212
213
214
215
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format
216
  if (substring[0] != "tcp:" || substring.size() != 2) {
217
218
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
219
               << "e.g, 'tcp://127.0.0.1:50051'. ";
220
221
222
223
224
225
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
226
               << "e.g, 'tcp://127.0.0.1:50051'. ";
227
228
229
230
  }
  std::string ip = ip_and_port[0];
  int port = stoi(ip_and_port[1]);
  // Initialize message queue for each connection
231
  num_sender_ = num_sender;
232
233
234
#ifdef USE_EPOLL
  if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {
      max_thread_count_ = num_sender_;
235
  }
236
237
238
#else
  max_thread_count_ = num_sender_;
#endif
239
240
  // Initialize socket and socket-thread
  server_socket_ = new TCPSocket();
241
  // Bind socket
242
  if (server_socket_->Bind(ip.c_str(), port) == false) {
243
    LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
244
  }
245

246
  // Listen
247
  if (server_socket_->Listen(kMaxConnection) == false) {
248
    LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
249
250
251
252
  }
  // Accept all sender sockets
  std::string accept_ip;
  int accept_port;
253
  sockets_.resize(max_thread_count_);
254
  for (int i = 0; i < num_sender_; ++i) {
255
256
257
258
259
    int thread_id = i % max_thread_count_;
    auto socket = std::make_shared<TCPSocket>();
    sockets_[thread_id][i] = socket;
    msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
    if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) == false) {
260
      LOG(WARNING) << "Error on accept socket.";
261
262
      return false;
    }
263
264
265
266
  }
  mq_iter_ = msg_queue_.begin();

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
267
    // create new thread for each socket
268
    threads_.push_back(std::make_shared<std::thread>(
269
      RecvLoop,
270
271
272
      sockets_[thread_id],
      msg_queue_,
      &queue_sem_));
273
274
275
276
277
  }

  return true;
}

278
/*!
 * \brief Receive one RPC message: a metadata message followed by the ndarray
 *        data messages it announces.
 *
 * The last sizeof(int32_t) bytes of the metadata message hold the number of
 * ndarray data messages that follow from the same sender.
 *
 * \param msg output RPC message
 * \param timeout timeout passed to the underlying receive calls; the log
 *        messages below report it in milliseconds
 * \return rpc::kRPCTimeOut if no metadata message arrived in time,
 *         rpc::kRPCSuccess otherwise; a timeout while waiting for ndarray
 *         data is reported through LOG(FATAL)
 */
rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {
  Message rpc_meta_msg;
  int send_id;
  auto status = Recv(&rpc_meta_msg, &send_id, timeout);
  if (status == QUEUE_EMPTY) {
    DLOG(WARNING) << "Timed out when trying to receive rpc meta data after "
                  << timeout << " milliseconds.";
    return rpc::kRPCTimeOut;
  }
  CHECK_EQ(status, REMOVE_SUCCESS);
  // The message must at least hold the trailing ndarray count, otherwise
  // the pointer arithmetic below would step in front of the buffer.
  CHECK_GE(static_cast<int64_t>(rpc_meta_msg.size),
           static_cast<int64_t>(sizeof(int32_t)))
    << "RPC meta message is too short to hold the ndarray count.";
  const char* count_ptr =
    rpc_meta_msg.data + rpc_meta_msg.size - sizeof(int32_t);
  // Copy the count out byte-wise: its position inside the buffer is not
  // guaranteed to be suitably aligned for a direct int32_t read.
  int32_t nonempty_ndarray_count = 0;
  memcpy(&nonempty_ndarray_count, count_ptr, sizeof(int32_t));
  CHECK_GE(nonempty_ndarray_count, 0)
    << "Invalid ndarray count in RPC meta message.";
  // Recv real ndarray data
  std::vector<void*> buffer_list(nonempty_ndarray_count);
  for (int i = 0; i < nonempty_ndarray_count; ++i) {
    Message ndarray_data_msg;
    status = RecvFrom(&ndarray_data_msg, send_id, timeout);
    if (status == QUEUE_EMPTY) {
      // As we cannot handle this timeout for now, let's treat it as fatal
      // error.
      LOG(FATAL) << "Timed out when trying to receive rpc ndarray data after "
                 << timeout << " milliseconds.";
      return rpc::kRPCTimeOut;
    }
    CHECK_EQ(status, REMOVE_SUCCESS);
    buffer_list[i] = ndarray_data_msg.data;
  }
  StreamWithBuffer zc_read_strm(rpc_meta_msg.data, rpc_meta_msg.size-sizeof(int32_t), buffer_list);
  zc_read_strm.Read(msg);
  // Release the metadata buffer; guard against an unset deallocator
  if (rpc_meta_msg.deallocator != nullptr) {
    rpc_meta_msg.deallocator(&rpc_meta_msg);
  }
  return rpc::kRPCSuccess;
}

311
/*!
 * \brief Receive the next available message from any sender.
 *
 * The queues are scanned round-robin, resuming after the queue that served
 * the previous call.
 *
 * \param msg output message
 * \param send_id output ID of the sender the message came from
 * \param timeout timeout passed to queue_sem_.TimedWait()
 * \return QUEUE_EMPTY if the semaphore wait timed out; otherwise the first
 *         status other than QUEUE_EMPTY returned by a queue's Remove()
 */
STATUS SocketReceiver::Recv(Message* msg, int* send_id, int timeout) {
  // queue_sem_ is a semaphore indicating how many elements in multiple
  // message queues.
  // When calling queue_sem_.TimedWait(), this Recv will be suspended until
  // queue_sem_ > 0 or specified timeout expires, decrease queue_sem_ by 1,
  // then start to fetch a message.
  if (!queue_sem_.TimedWait(timeout)) {
    return QUEUE_EMPTY;
  }
  // Keep cycling over the queues until one of them yields a result. The
  // loop only exits through the return below, so no code is needed after it.
  for (;;) {
    for (; mq_iter_ != msg_queue_.end(); ++mq_iter_) {
      STATUS code = mq_iter_->second->Remove(msg, false);
      if (code == QUEUE_EMPTY) {
        continue;  // jump to the next queue
      }
      *send_id = mq_iter_->first;
      // Start the next call at the following queue
      ++mq_iter_;
      return code;
    }
    // Reached the end, wrap around to the first queue
    mq_iter_ = msg_queue_.begin();
  }
}

338
/*!
 * \brief Receive the next message from one specific sender.
 *
 * \param msg output message
 * \param send_id ID of the sender to receive from
 * \param timeout timeout passed to queue_sem_.TimedWait()
 * \return QUEUE_EMPTY if the semaphore wait timed out; otherwise the status
 *         returned by the sender queue's Remove(). An unknown sender ID is
 *         reported through LOG(FATAL).
 */
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id, int timeout) {
  // Look the queue up with find(): operator[] would insert a null queue for
  // an unknown sender ID, which would then be dereferenced, and the
  // insertion could also invalidate mq_iter_.
  auto queue_it = msg_queue_.find(send_id);
  if (queue_it == msg_queue_.end()) {
    LOG(FATAL) << "No message queue exists for sender ID " << send_id;
  }
  if (!queue_sem_.TimedWait(timeout)) {
    return QUEUE_EMPTY;
  }
  // Get message from specified message queue
  return queue_it->second->Remove(msg);
}

void SocketReceiver::Finalize() {
  // Send a signal to tell the message queue to finish its job
  for (auto& mq : msg_queue_) {
    // wait until queue is empty
    while (mq.second->Empty() == false) {
352
        std::this_thread::sleep_for(std::chrono::seconds(1));
353
    }
354
    mq.second->SignalFinished(mq.first);
355
356
357
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
358
    thread->join();
359
360
  }
  // Clear all sockets
361
362
363
364
  for (auto& group_sockets : sockets_) {
    for (auto& socket : group_sockets) {
      socket.second->Close();
    }
365
  }
Chao Ma's avatar
Chao Ma committed
366
367
  server_socket_->Close();
  delete server_socket_;
368
369
}

370
371
372
373
374
375
376
377
378
379
380
381
382
383
/*!
 * \brief Read the int64-sized length field that precedes every message.
 *
 * \param socket socket to read from
 * \return the length field, or -1 if Receive() reports -1 before any byte of
 *         the field has been read
 */
int64_t RecvDataSize(TCPSocket* socket) {
  int64_t data_size = 0;
  char* const size_bytes = reinterpret_cast<char*>(&data_size);
  int64_t received_bytes = 0;
  while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) {
    const int64_t max_len = sizeof(int64_t) - received_bytes;
    const int64_t got = socket->Receive(size_bytes + received_bytes, max_len);
    if (got != -1) {
      received_bytes += got;
      continue;
    }
    if (received_bytes == 0) {
      // Nothing read so far, report it to the caller
      return -1;
    }
    // The field is partially read: keep trying until it is complete
  }
  return data_size;
}

/*!
 * \brief Read message payload into a buffer until it is complete or
 *        Receive() reports -1.
 *
 * \param socket socket to read from
 * \param buffer destination buffer
 * \param data_size total payload size expected in the buffer
 * \param received_bytes in/out: bytes already in the buffer; advanced by
 *        the amount read in this call
 */
void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size,
  int64_t *received_bytes) {
  int64_t& filled = *received_bytes;
  while (filled < data_size) {
    const int64_t wanted = data_size - filled;
    const int64_t got = socket->Receive(buffer + filled, wanted);
    if (got == -1) {
      // Socket not ready, no more data to read
      return;
    }
    filled += got;
  }
}

void SocketReceiver::RecvLoop(
  std::unordered_map<int /* Sender (virtual) ID */,
    std::shared_ptr<TCPSocket>> sockets,
  std::unordered_map<int /* Sender (virtual) ID */,
    std::shared_ptr<MessageQueue>> queues,
  runtime::Semaphore *queue_sem) {
  std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;
  SocketPool socket_pool;
  for (auto& socket : sockets) {
    auto &sender_id = socket.first;
    socket_pool.AddSocket(socket.second, sender_id);
    recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());
  }

  // Main loop to receive messages
  for (;;) {
    int sender_id;
    // Get active socket using epoll
    std::shared_ptr<TCPSocket> socket = socket_pool.GetActiveSocket(&sender_id);
    if (queues[sender_id]->EmptyAndNoMoreAdd()) {
      // This sender has already stopped
      if (socket_pool.RemoveSocket(socket) == 0) {
        return;
426
      }
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
      continue;
    }

    // Nonblocking socket might be interrupted at any point. So we need to
    // store the partially received data
    std::unique_ptr<RecvContext> &ctx = recv_contexts[sender_id];
    int64_t &data_size = ctx->data_size;
    int64_t &received_bytes = ctx->received_bytes;
    char*& buffer = ctx->buffer;

    if (data_size == -1) {
      // This is a new message, so receive the data size first
      data_size = RecvDataSize(socket.get());
      if (data_size > 0) {
        try {
          buffer = new char[data_size];
        } catch(const std::bad_alloc&) {
          LOG(FATAL) << "Cannot allocate enough memory for message, "
                     << "(message size: " << data_size << ")";
        }
        received_bytes = 0;
      } else if (data_size == 0) {
        // Received stop signal
        if (socket_pool.RemoveSocket(socket) == 0) {
          return;
        }
453
      }
454
455
456
457
458
    }

    RecvData(socket.get(), buffer, data_size, &received_bytes);
    if (received_bytes >= data_size) {
      // Full data received, create Message and push to queue
459
460
461
462
      Message msg;
      msg.data = buffer;
      msg.size = data_size;
      msg.deallocator = DefaultMessageDeleter;
463
464
465
466
467
468
469
      queues[sender_id]->Add(msg);

      // Reset recv context
      data_size = -1;

      // Signal queue semaphore
      queue_sem->Post();
470
471
472
473
474
475
    }
  }
}

}  // namespace network
}  // namespace dgl