socket_communicator.cc 14.8 KB
Newer Older
1
/**
2
 *  Copyright (c) 2019 by Contributors
3
4
 * @file socket_communicator.cc
 * @brief SocketCommunicator for DGL distributed training.
5
 */
6
#include "socket_communicator.h"
7

8
#include <dmlc/logging.h>
9
#include <stdlib.h>
10
#include <string.h>
11
#include <time.h>
12

13
#include <memory>
14

15
#include "../../c_api_common.h"
16
#include "socket_pool.h"
17
18
19

#ifdef _WIN32
#include <windows.h>
20
#else  // !_WIN32
21
22
23
24
25
26
#include <unistd.h>
#endif  // _WIN32

namespace dgl {
namespace network {

27
28
/////////////////////////////////////// SocketSender
//////////////////////////////////////////////
29

30
/**
 * @brief Record the address of a receiver so that it can be connected later.
 *
 * This function only parses `addr` and stores the result in
 * `receiver_addrs_`; no socket is opened here.
 *
 * @param addr Receiver address, expected as 'tcp://<ip>:<port>'.
 * @param recv_id ID of the receiver; must not be negative.
 * @return true once the address has been stored. Invalid input is reported
 *         with LOG(FATAL).
 */
bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
  if (recv_id < 0) {
    LOG(FATAL) << "recv_id cannot be a negative number.";
  }
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format. The size is tested first so that substring[0] is
  // never read from a vector that does not hold exactly two pieces.
  if (substring.size() != 2 || substring[0] != "tcp:") {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'tcp://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'tcp://127.0.0.1:50051'. ";
  }
  IPAddr address;
  address.ip = ip_and_port[0];
  address.port = std::stoi(ip_and_port[1]);
  receiver_addrs_[recv_id] = address;

  return true;
}

58
bool SocketSender::ConnectReceiverFinalize(const int max_try_times) {
59
  // Create N sockets for Receiver
60
61
62
63
64
  int receiver_count = static_cast<int>(receiver_addrs_.size());
  if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
    max_thread_count_ = receiver_count;
  }
  sockets_.resize(max_thread_count_);
65
  for (const auto& r : receiver_addrs_) {
66
67
68
69
    int receiver_id = r.first;
    int thread_id = receiver_id % max_thread_count_;
    sockets_[thread_id][receiver_id] = std::make_shared<TCPSocket>();
    TCPSocket* client_socket = sockets_[thread_id][receiver_id].get();
70
71
    bool bo = false;
    int try_count = 0;
72
73
    const char* ip = r.second.ip.c_str();
    int port = r.second.port;
74
    while (bo == false && try_count < max_try_times) {
75
      if (client_socket->Connect(ip, port)) {
76
77
        bo = true;
      } else {
78
        if (try_count % 200 == 0 && try_count != 0) {
79
80
          // every 600 seconds show this message
          LOG(INFO) << "Trying to connect receiver: " << ip << ":" << port;
81
        }
82
        try_count++;
83
        std::this_thread::sleep_for(std::chrono::seconds(3));
84
85
86
87
88
      }
    }
    if (bo == false) {
      return bo;
    }
89
90
91
92
  }

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
    msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));
93
    // Create a new thread for this socket connection
94
    threads_.push_back(std::make_shared<std::thread>(
95
        SendLoop, sockets_[thread_id], msg_queue_[thread_id]));
96
  }
97

98
99
100
  return true;
}

101
102
103
104
105
/**
 * @brief Serialize an RPC message and queue it for delivery to a receiver.
 *
 * The message is written into one meta blob, to which the number of
 * buffers reported by the stream is appended as a trailing int32_t. The
 * meta blob is queued first, followed by one message per buffer.
 *
 * @param msg RPC message to send.
 * @param recv_id ID of the target receiver.
 */
void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
  auto meta_blob = std::make_shared<std::string>();
  StreamWithBuffer writer(meta_blob.get(), true);
  writer.Write(msg);

  // Append the buffer count to the end of the meta blob.
  int32_t buffer_count = writer.buffer_list().size();
  meta_blob->append(
      reinterpret_cast<char*>(&buffer_count), sizeof(int32_t));

  Message meta;
  meta.data = const_cast<char*>(meta_blob->data());
  meta.size = meta_blob->size();
  // The captured shared_ptr keeps the blob alive as long as the message.
  meta.deallocator = [meta_blob](Message*) {};
  CHECK_EQ(Send(meta, recv_id), ADD_SUCCESS);

  // send real ndarray data
  for (const auto& buf : writer.buffer_list()) {
    if (buf.size == 0) {
      LOG(FATAL) << "Cannot send a empty NDArray.";
    }
    Message payload;
    payload.data = reinterpret_cast<char*>(buf.data);
    payload.size = buf.size;
    // The captured NDArray copy is held for the lifetime of the message.
    NDArray tensor = buf.tensor;
    payload.deallocator = [tensor](Message*) {};
    CHECK_EQ(Send(payload, recv_id), ADD_SUCCESS);
  }
}

127
128
129
130
/**
 * @brief Put a raw message into the queue that serves the given receiver.
 *
 * @param msg Message to send; its data must be non-null and its size
 *        positive.
 * @param recv_id Non-negative ID of the target receiver.
 * @return Status code returned by the queue's Add().
 */
STATUS SocketSender::Send(Message msg, int recv_id) {
  CHECK_NOTNULL(msg.data);
  CHECK_GT(msg.size, 0);
  CHECK_GE(recv_id, 0);
  msg.receiver_id = recv_id;
  // The queue is selected the same way sockets are grouped per thread.
  auto& target_queue = msg_queue_[recv_id % max_thread_count_];
  return target_queue->Add(msg);
}

void SocketSender::Finalize() {
138
  // Send a signal to tell the msg_queue to finish its job
139
  for (int i = 0; i < max_thread_count_; ++i) {
140
    // wait until queue is empty
141
142
    auto& mq = msg_queue_[i];
    while (mq->Empty() == false) {
143
      std::this_thread::sleep_for(std::chrono::seconds(1));
144
    }
145
146
147
    // All queues have only one producer, which is main thread, so
    // the producerID argument here should be zero.
    mq->SignalFinished(0);
148
149
150
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
151
    thread->join();
152
153
  }
  // Clear all sockets
154
  for (auto& group_sockets_ : sockets_) {
155
    for (auto& socket : group_sockets_) {
156
157
158
159
160
161
162
163
164
165
166
      socket.second->Close();
    }
  }
}

/**
 * @brief Write one message to a socket: first its size, then its payload.
 *
 * The size is sent as the raw bytes of `msg.size` (sizeof(int64_t) bytes).
 * A message with size zero therefore consists of the size field only.
 * After sending, the message's deallocator is invoked if it is set.
 *
 * @param msg Message to send.
 * @param socket Socket to write to.
 */
void SendCore(Message msg, TCPSocket* socket) {
  // Size header: keep sending until all of its bytes are out.
  char* const header = reinterpret_cast<char*>(&msg.size);
  int64_t done = 0;
  while (static_cast<size_t>(done) < sizeof(int64_t)) {
    const int64_t remaining = sizeof(int64_t) - done;
    const int64_t written = socket->Send(header + done, remaining);
    CHECK_NE(written, -1);
    done += written;
  }

  // Payload.
  for (done = 0; done < msg.size;) {
    const int64_t remaining = msg.size - done;
    const int64_t written = socket->Send(msg.data + done, remaining);
    CHECK_NE(written, -1);
    done += written;
  }

  // Release the message.
  if (msg.deallocator != nullptr) {
    msg.deallocator(&msg);
  }
}

186
void SocketSender::SendLoop(
187
188
    std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
    std::shared_ptr<MessageQueue> queue) {
189
  for (;;) {
190
191
192
193
    Message msg;
    STATUS code = queue->Remove(&msg);
    if (code == QUEUE_CLOSE) {
      msg.size = 0;  // send an end-signal to receiver
194
195
196
197
      for (auto& socket : sockets) {
        SendCore(msg, socket.second.get());
      }
      break;
198
    }
199
    SendCore(msg, sockets[msg.receiver_id].get());
200
  }
201
202
}

203
204
205
206
/////////////////////////////////////// SocketReceiver
//////////////////////////////////////////////
bool SocketReceiver::Wait(
    const std::string& addr, int num_sender, bool blocking) {
207
  CHECK_GT(num_sender, 0);
208
  CHECK_EQ(blocking, true);
209
210
211
212
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format
213
  if (substring[0] != "tcp:" || substring.size() != 2) {
214
215
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
216
               << "e.g, 'tcp://127.0.0.1:50051'. ";
217
218
219
220
221
222
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
223
               << "e.g, 'tcp://127.0.0.1:50051'. ";
224
225
226
227
  }
  std::string ip = ip_and_port[0];
  int port = stoi(ip_and_port[1]);
  // Initialize message queue for each connection
228
  num_sender_ = num_sender;
229
230
#ifdef USE_EPOLL
  if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {
231
    max_thread_count_ = num_sender_;
232
  }
233
234
235
#else
  max_thread_count_ = num_sender_;
#endif
236
237
  // Initialize socket and socket-thread
  server_socket_ = new TCPSocket();
238
  // Bind socket
239
  if (server_socket_->Bind(ip.c_str(), port) == false) {
240
    LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
241
  }
242

243
  // Listen
244
  if (server_socket_->Listen(kMaxConnection) == false) {
245
    LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
246
247
248
249
  }
  // Accept all sender sockets
  std::string accept_ip;
  int accept_port;
250
  sockets_.resize(max_thread_count_);
251
  for (int i = 0; i < num_sender_; ++i) {
252
253
254
255
    int thread_id = i % max_thread_count_;
    auto socket = std::make_shared<TCPSocket>();
    sockets_[thread_id][i] = socket;
    msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
256
257
    if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) ==
        false) {
258
      LOG(WARNING) << "Error on accept socket.";
259
260
      return false;
    }
261
262
263
264
  }
  mq_iter_ = msg_queue_.begin();

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
265
    // create new thread for each socket
266
    threads_.push_back(std::make_shared<std::thread>(
267
        RecvLoop, sockets_[thread_id], msg_queue_, &queue_sem_));
268
269
270
271
272
  }

  return true;
}

273
/**
 * @brief Receive one complete RPC message (meta data plus buffers).
 *
 * The meta message ends with an int32_t holding the number of buffer
 * messages that follow from the same sender. Those are collected and
 * handed to the stream that deserializes `msg`.
 *
 * @param msg Output RPC message.
 * @param timeout Timeout passed to the underlying receive calls; the log
 *        messages report it in milliseconds.
 * @return rpc::kRPCTimeOut if no meta message arrives in time,
 *         rpc::kRPCSuccess otherwise.
 */
rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {
  Message rpc_meta_msg;
  int send_id;
  auto status = Recv(&rpc_meta_msg, &send_id, timeout);
  if (status == QUEUE_EMPTY) {
    DLOG(WARNING) << "Timed out when trying to receive rpc meta data after "
                  << timeout << " milliseconds.";
    return rpc::kRPCTimeOut;
  }
  CHECK_EQ(status, REMOVE_SUCCESS);
  // The message must be long enough to hold the trailing count.
  CHECK_GE(rpc_meta_msg.size, static_cast<int64_t>(sizeof(int32_t)));
  const char* count_ptr =
      rpc_meta_msg.data + rpc_meta_msg.size - sizeof(int32_t);
  // The count sits at an arbitrary offset in the buffer, so copy it out
  // instead of dereferencing a possibly misaligned int32_t pointer.
  int32_t nonempty_ndarray_count = 0;
  memcpy(&nonempty_ndarray_count, count_ptr, sizeof(int32_t));
  // A negative count would turn into a huge vector size below.
  CHECK_GE(nonempty_ndarray_count, 0);
  // Recv real ndarray data
  std::vector<void*> buffer_list(nonempty_ndarray_count);
  for (int i = 0; i < nonempty_ndarray_count; ++i) {
    Message ndarray_data_msg;
    // As meta message has been received, data message is always expected unless
    // connection is closed.
    STATUS status;
    do {
      status = RecvFrom(&ndarray_data_msg, send_id, timeout);
      if (status == QUEUE_EMPTY) {
        DLOG(WARNING)
            << "Timed out when trying to receive rpc ndarray data after "
            << timeout << " milliseconds.";
      }
    } while (status == QUEUE_EMPTY);
    CHECK_EQ(status, REMOVE_SUCCESS);
    buffer_list[i] = ndarray_data_msg.data;
  }
  // Deserialize from the meta data without the trailing count.
  StreamWithBuffer zc_read_strm(
      rpc_meta_msg.data, rpc_meta_msg.size - sizeof(int32_t), buffer_list);
  zc_read_strm.Read(msg);
  rpc_meta_msg.deallocator(&rpc_meta_msg);
  return rpc::kRPCSuccess;
}

310
/**
 * @brief Receive the next available message from any sender.
 *
 * Waits on `queue_sem_`, which counts the messages available across all
 * queues, then scans the queues starting at `mq_iter_` until a
 * non-blocking Remove() returns something other than QUEUE_EMPTY.
 *
 * @param msg Output message.
 * @param send_id Output: ID of the queue the message was taken from.
 * @param timeout Timeout passed to the semaphore wait.
 * @return QUEUE_EMPTY if the wait timed out, otherwise the status
 *         returned by Remove().
 */
STATUS SocketReceiver::Recv(Message* msg, int* send_id, int timeout) {
  // The wait succeeds once the semaphore is positive and decrements it;
  // on timeout there is nothing to fetch.
  if (!queue_sem_.TimedWait(timeout)) {
    return QUEUE_EMPTY;
  }
  while (true) {
    while (mq_iter_ != msg_queue_.end()) {
      const STATUS code = mq_iter_->second->Remove(msg, false);
      if (code != QUEUE_EMPTY) {
        *send_id = mq_iter_->first;
        // Resume with the following queue on the next call.
        ++mq_iter_;
        return code;
      }
      ++mq_iter_;  // this queue is empty, try the next one
    }
    // Reached the end: wrap around and scan again.
    mq_iter_ = msg_queue_.begin();
  }
  // Not reachable: the loop above only exits by returning.
  LOG(ERROR)
      << "Failed to remove message from queue due to unexpected queue status.";
  return QUEUE_CLOSE;
}

337
/**
 * @brief Receive the next message from one specific sender.
 *
 * @param msg Output message.
 * @param send_id ID of the sender whose queue is read.
 * @param timeout Timeout passed to the semaphore wait.
 * @return QUEUE_EMPTY if the wait timed out, otherwise the status
 *         returned by the queue's Remove().
 */
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id, int timeout) {
  if (queue_sem_.TimedWait(timeout)) {
    // A message is available somewhere; take one from the requested queue.
    return msg_queue_[send_id]->Remove(msg);
  }
  return QUEUE_EMPTY;
}

void SocketReceiver::Finalize() {
  // Send a signal to tell the message queue to finish its job
  for (auto& mq : msg_queue_) {
    // wait until queue is empty
    while (mq.second->Empty() == false) {
351
      std::this_thread::sleep_for(std::chrono::seconds(1));
352
    }
353
    mq.second->SignalFinished(mq.first);
354
355
356
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
357
    thread->join();
358
359
  }
  // Clear all sockets
360
361
362
363
  for (auto& group_sockets : sockets_) {
    for (auto& socket : group_sockets) {
      socket.second->Close();
    }
364
  }
Chao Ma's avatar
Chao Ma committed
365
366
  server_socket_->Close();
  delete server_socket_;
367
368
}

369
370
371
372
373
374
/**
 * @brief Read the size field that precedes every message.
 *
 * @param socket Socket to read from.
 * @return The size read from the socket, or -1 if Receive() returned -1
 *         before any byte of the size had been read.
 */
int64_t RecvDataSize(TCPSocket* socket) {
  int64_t data_size = 0;
  char* const dst = reinterpret_cast<char*>(&data_size);
  int64_t received_bytes = 0;
  while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) {
    const int64_t max_len = sizeof(int64_t) - received_bytes;
    const int64_t got = socket->Receive(dst + received_bytes, max_len);
    if (got != -1) {
      received_bytes += got;
      continue;
    }
    // Nothing of the size field has arrived yet: report it to the caller.
    if (received_bytes <= 0) {
      return -1;
    }
    // Part of the size field is already in; keep retrying until it is
    // complete.
  }
  return data_size;
}

388
389
390
/**
 * @brief Read message payload until it is complete or Receive() returns -1.
 *
 * @param socket Socket to read from.
 * @param buffer Destination buffer for the payload.
 * @param data_size Total payload size.
 * @param received_bytes In/out: number of payload bytes already stored in
 *        `buffer`; advanced by the number of bytes read here.
 */
void RecvData(
    TCPSocket* socket, char* buffer, const int64_t& data_size,
    int64_t* received_bytes) {
  int64_t& done = *received_bytes;
  while (done < data_size) {
    const int64_t got = socket->Receive(buffer + done, data_size - done);
    if (got == -1) {
      // Socket not ready, no more data to read
      break;
    }
    done += got;
  }
}

void SocketReceiver::RecvLoop(
403
404
405
406
407
408
409
    std::unordered_map<
        int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>
        sockets,
    std::unordered_map<
        int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
        queues,
    runtime::Semaphore* queue_sem) {
410
411
412
  std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;
  SocketPool socket_pool;
  for (auto& socket : sockets) {
413
    auto& sender_id = socket.first;
414
415
416
417
418
419
420
421
422
423
424
425
426
    socket_pool.AddSocket(socket.second, sender_id);
    recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());
  }

  // Main loop to receive messages
  for (;;) {
    int sender_id;
    // Get active socket using epoll
    std::shared_ptr<TCPSocket> socket = socket_pool.GetActiveSocket(&sender_id);
    if (queues[sender_id]->EmptyAndNoMoreAdd()) {
      // This sender has already stopped
      if (socket_pool.RemoveSocket(socket) == 0) {
        return;
427
      }
428
429
430
431
432
      continue;
    }

    // Nonblocking socket might be interrupted at any point. So we need to
    // store the partially received data
433
434
435
    std::unique_ptr<RecvContext>& ctx = recv_contexts[sender_id];
    int64_t& data_size = ctx->data_size;
    int64_t& received_bytes = ctx->received_bytes;
436
437
438
439
440
441
442
443
    char*& buffer = ctx->buffer;

    if (data_size == -1) {
      // This is a new message, so receive the data size first
      data_size = RecvDataSize(socket.get());
      if (data_size > 0) {
        try {
          buffer = new char[data_size];
444
        } catch (const std::bad_alloc&) {
445
446
447
448
449
450
451
452
453
          LOG(FATAL) << "Cannot allocate enough memory for message, "
                     << "(message size: " << data_size << ")";
        }
        received_bytes = 0;
      } else if (data_size == 0) {
        // Received stop signal
        if (socket_pool.RemoveSocket(socket) == 0) {
          return;
        }
454
      }
455
456
457
458
459
    }

    RecvData(socket.get(), buffer, data_size, &received_bytes);
    if (received_bytes >= data_size) {
      // Full data received, create Message and push to queue
460
461
462
463
      Message msg;
      msg.data = buffer;
      msg.size = data_size;
      msg.deallocator = DefaultMessageDeleter;
464
465
466
467
468
469
470
      queues[sender_id]->Add(msg);

      // Reset recv context
      data_size = -1;

      // Signal queue semaphore
      queue_sem->Post();
471
472
473
474
475
476
    }
  }
}

}  // namespace network
}  // namespace dgl