/**
 *  Copyright (c) 2019 by Contributors
 * @file socket_communicator.cc
 * @brief SocketCommunicator for DGL distributed training.
 */
6
#include "socket_communicator.h"
7

8
#include <dmlc/logging.h>
9
#include <stdlib.h>
10
#include <string.h>
11
#include <time.h>
12

13
#include <memory>
14

15
#include "../../c_api_common.h"
16
#include "socket_pool.h"
17
18
19

#ifdef _WIN32
#include <windows.h>
20
#else  // !_WIN32
21
22
23
24
25
26
#include <unistd.h>
#endif  // _WIN32

namespace dgl {
namespace network {

27
28
/////////////////////////////////////// SocketSender
//////////////////////////////////////////////
29

30
/**
 * @brief Register the address of receiver `recv_id`. No connection is made
 *        here; sockets are opened later by ConnectReceiverFinalize().
 * @param addr Address in the form "tcp://<ip>:<port>".
 * @param recv_id Non-negative receiver ID; a later call with the same ID
 *        overwrites the earlier address.
 * @return Always true; malformed input is reported through LOG(FATAL).
 */
bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
  if (recv_id < 0) {
    LOG(FATAL) << "recv_id cannot be a negative number.";
  }
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format. The size test must come first: an empty address
  // produces no pieces, and reading substring[0] would be out of bounds.
  if (substring.size() != 2 || substring[0] != "tcp:") {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'tcp://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'tcp://127.0.0.1:50051'. ";
  }
  IPAddr address;
  address.ip = ip_and_port[0];
  address.port = std::stoi(ip_and_port[1]);
  receiver_addrs_[recv_id] = address;
  return true;
}

58
bool SocketSender::ConnectReceiverFinalize(const int max_try_times) {
59
  // Create N sockets for Receiver
60
61
62
63
64
  int receiver_count = static_cast<int>(receiver_addrs_.size());
  if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
    max_thread_count_ = receiver_count;
  }
  sockets_.resize(max_thread_count_);
65
  for (const auto& r : receiver_addrs_) {
66
67
68
69
    int receiver_id = r.first;
    int thread_id = receiver_id % max_thread_count_;
    sockets_[thread_id][receiver_id] = std::make_shared<TCPSocket>();
    TCPSocket* client_socket = sockets_[thread_id][receiver_id].get();
70
71
    bool bo = false;
    int try_count = 0;
72
73
    const char* ip = r.second.ip.c_str();
    int port = r.second.port;
74
    while (bo == false && try_count < max_try_times) {
75
      if (client_socket->Connect(ip, port)) {
76
77
        bo = true;
      } else {
78
        if (try_count % 200 == 0 && try_count != 0) {
79
80
          // every 600 seconds show this message
          LOG(INFO) << "Trying to connect receiver: " << ip << ":" << port;
81
        }
82
        try_count++;
83
        std::this_thread::sleep_for(std::chrono::seconds(3));
84
85
86
87
88
      }
    }
    if (bo == false) {
      return bo;
    }
89
90
91
92
  }

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
    msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));
93
    // Create a new thread for this socket connection
94
    threads_.push_back(std::make_shared<std::thread>(
95
        SendLoop, sockets_[thread_id], msg_queue_[thread_id]));
96
  }
97

98
99
100
  return true;
}

101
102
103
104
105
/**
 * @brief Serialize an RPC message and enqueue it for receiver `recv_id`.
 *
 * Several queue messages are produced:
 *   1. a meta message holding the serialized RPC message followed by an
 *      int32_t that counts the NDArray buffers collected by the stream;
 *   2. one message per such buffer, sent zero-copy from the NDArray memory.
 * Each deallocator is a no-op that captures an owner (the blob or the
 * NDArray), keeping the bytes alive until the sending thread is done.
 */
void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
  auto meta_blob = std::make_shared<std::string>();
  StreamWithBuffer meta_stream(meta_blob.get(), true);
  meta_stream.Write(msg);

  const auto& payloads = meta_stream.buffer_list();
  int32_t payload_count = payloads.size();
  meta_blob->append(
      reinterpret_cast<char*>(&payload_count), sizeof(int32_t));

  Message meta;
  meta.data = const_cast<char*>(meta_blob->data());
  meta.size = meta_blob->size();
  meta.deallocator = [meta_blob](Message*) {};
  CHECK_EQ(Send(meta, recv_id), ADD_SUCCESS);

  // send real ndarray data
  for (const auto& payload : payloads) {
    if (payload.size == 0) {
      LOG(FATAL) << "Cannot send a empty NDArray.";
    }
    NDArray keep_alive = payload.tensor;
    Message data_msg;
    data_msg.data = reinterpret_cast<char*>(payload.data);
    data_msg.size = payload.size;
    data_msg.deallocator = [keep_alive](Message*) {};
    CHECK_EQ(Send(data_msg, recv_id), ADD_SUCCESS);
  }
}

127
128
129
130
/**
 * @brief Enqueue a raw message for receiver `recv_id`.
 *
 * The message goes to the queue of the sending thread that owns the
 * receiver's socket (`recv_id % max_thread_count_`).
 *
 * @param msg Message with non-null data and a positive size.
 * @param recv_id A receiver registered through ConnectReceiver().
 * @return The status returned by MessageQueue::Add().
 */
STATUS SocketSender::Send(Message msg, int recv_id) {
  CHECK_NOTNULL(msg.data);
  CHECK_GT(msg.size, 0);
  CHECK_GE(recv_id, 0);
  // Without these checks the failure surfaces later: an empty msg_queue_ (or
  // max_thread_count_ == 0) is indexed out of range, and an unknown receiver
  // makes the sending thread dereference a missing socket.
  CHECK(!msg_queue_.empty())
      << "ConnectReceiverFinalize() must succeed before sending messages.";
  CHECK(receiver_addrs_.count(recv_id) > 0)
      << "Receiver " << recv_id << " was not registered by ConnectReceiver().";
  msg.receiver_id = recv_id;
  // Add data message to message queue
  return msg_queue_[recv_id % max_thread_count_]->Add(msg);
}

void SocketSender::Finalize() {
138
  // Send a signal to tell the msg_queue to finish its job
139
  for (int i = 0; i < max_thread_count_; ++i) {
140
    // wait until queue is empty
141
142
    auto& mq = msg_queue_[i];
    while (mq->Empty() == false) {
143
      std::this_thread::sleep_for(std::chrono::seconds(1));
144
    }
145
146
147
    // All queues have only one producer, which is main thread, so
    // the producerID argument here should be zero.
    mq->SignalFinished(0);
148
149
150
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
151
    thread->join();
152
153
  }
  // Clear all sockets
154
  for (auto& group_sockets_ : sockets_) {
155
    for (auto& socket : group_sockets_) {
156
157
158
159
160
161
162
163
164
165
166
      socket.second->Close();
    }
  }
}

/**
 * @brief Write one framed message to `socket` and release its payload.
 *
 * Frame layout: the 8 raw bytes of `msg.size`, then `msg.size` payload bytes.
 * A size of 0 carries no payload; SendLoop() uses it as the end-of-stream
 * signal. Any socket write that returns -1 is fatal. Afterwards the message's
 * deallocator, if set, is invoked.
 */
void SendCore(Message msg, TCPSocket* socket) {
  // Write exactly `len` bytes from `bytes`, retrying on short writes.
  auto send_all = [socket](char* bytes, int64_t len) {
    int64_t offset = 0;
    while (offset < len) {
      const int64_t written = socket->Send(bytes + offset, len - offset);
      CHECK_NE(written, -1);
      offset += written;
    }
  };

  send_all(reinterpret_cast<char*>(&msg.size), sizeof(int64_t));  // header
  send_all(msg.data, msg.size);                                    // payload

  if (msg.deallocator != nullptr) {
    msg.deallocator(&msg);
  }
}

186
void SocketSender::SendLoop(
187
188
    std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
    std::shared_ptr<MessageQueue> queue) {
189
  for (;;) {
190
191
192
193
    Message msg;
    STATUS code = queue->Remove(&msg);
    if (code == QUEUE_CLOSE) {
      msg.size = 0;  // send an end-signal to receiver
194
195
196
197
      for (auto& socket : sockets) {
        SendCore(msg, socket.second.get());
      }
      break;
198
    }
199
    SendCore(msg, sockets[msg.receiver_id].get());
200
  }
201
202
}

203
204
/////////////////////////////////////// SocketReceiver
//////////////////////////////////////////////
205
bool SocketReceiver::Wait(const std::string& addr, int num_sender) {
206
207
208
209
210
  CHECK_GT(num_sender, 0);
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format
211
  if (substring[0] != "tcp:" || substring.size() != 2) {
212
213
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
214
               << "e.g, 'tcp://127.0.0.1:50051'. ";
215
216
217
218
219
220
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
221
               << "e.g, 'tcp://127.0.0.1:50051'. ";
222
223
224
225
  }
  std::string ip = ip_and_port[0];
  int port = stoi(ip_and_port[1]);
  // Initialize message queue for each connection
226
  num_sender_ = num_sender;
227
228
#ifdef USE_EPOLL
  if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {
229
    max_thread_count_ = num_sender_;
230
  }
231
232
233
#else
  max_thread_count_ = num_sender_;
#endif
234
235
  // Initialize socket and socket-thread
  server_socket_ = new TCPSocket();
236
  // Bind socket
237
  if (server_socket_->Bind(ip.c_str(), port) == false) {
238
    LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
239
  }
240

241
  // Listen
242
  if (server_socket_->Listen(kMaxConnection) == false) {
243
    LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
244
245
246
247
  }
  // Accept all sender sockets
  std::string accept_ip;
  int accept_port;
248
  sockets_.resize(max_thread_count_);
249
  for (int i = 0; i < num_sender_; ++i) {
250
251
252
253
    int thread_id = i % max_thread_count_;
    auto socket = std::make_shared<TCPSocket>();
    sockets_[thread_id][i] = socket;
    msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
254
255
    if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) ==
        false) {
256
      LOG(WARNING) << "Error on accept socket.";
257
258
      return false;
    }
259
260
261
262
  }
  mq_iter_ = msg_queue_.begin();

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
263
    // create new thread for each socket
264
    threads_.push_back(std::make_shared<std::thread>(
265
        RecvLoop, sockets_[thread_id], msg_queue_, &queue_sem_));
266
267
268
269
270
  }

  return true;
}

271
/**
 * @brief Receive one RPC message and deserialize it into `msg`.
 *
 * First takes a meta message from any sender. Its last 4 bytes hold the
 * number of NDArray buffers that follow (see SocketSender::Send). Those
 * buffers are then read from the same sender, and `msg` is rebuilt from the
 * meta bytes plus the buffers.
 *
 * @param msg Output RPC message.
 * @param timeout Timeout in milliseconds for the meta message.
 * @return kRPCTimeOut if no meta message arrived in time, kRPCSuccess
 *         otherwise. Once the meta message has arrived, waiting for the
 *         NDArray data is retried until it comes.
 */
rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {
  Message rpc_meta_msg;
  int send_id;
  auto status = Recv(&rpc_meta_msg, &send_id, timeout);
  if (status == QUEUE_EMPTY) {
    DLOG(WARNING) << "Timed out when trying to receive rpc meta data after "
                  << timeout << " milliseconds.";
    return rpc::kRPCTimeOut;
  }
  CHECK_EQ(status, REMOVE_SUCCESS);

  const int64_t count_bytes = static_cast<int64_t>(sizeof(int32_t));
  CHECK_GE(rpc_meta_msg.size, count_bytes)
      << "RPC meta message is too short to hold the NDArray count.";
  // memcpy rather than dereferencing a cast pointer: the trailer sits at an
  // arbitrary byte offset and may be misaligned for an int32_t load.
  int32_t nonempty_ndarray_count = 0;
  memcpy(
      &nonempty_ndarray_count,
      rpc_meta_msg.data + rpc_meta_msg.size - count_bytes, sizeof(int32_t));
  CHECK_GE(nonempty_ndarray_count, 0)
      << "Invalid NDArray count in RPC meta message.";

  // Recv real ndarray data
  std::vector<void*> buffer_list(nonempty_ndarray_count);
  for (int i = 0; i < nonempty_ndarray_count; ++i) {
    Message ndarray_data_msg;
    // As meta message has been received, data message is always expected unless
    // connection is closed.
    STATUS data_status;
    do {
      data_status = RecvFrom(&ndarray_data_msg, send_id, timeout);
      if (data_status == QUEUE_EMPTY) {
        DLOG(WARNING)
            << "Timed out when trying to receive rpc ndarray data after "
            << timeout << " milliseconds.";
      }
    } while (data_status == QUEUE_EMPTY);
    CHECK_EQ(data_status, REMOVE_SUCCESS);
    buffer_list[i] = ndarray_data_msg.data;
  }

  StreamWithBuffer zc_read_strm(
      rpc_meta_msg.data, rpc_meta_msg.size - count_bytes, buffer_list);
  zc_read_strm.Read(msg);
  rpc_meta_msg.deallocator(&rpc_meta_msg);
  return rpc::kRPCSuccess;
}

308
/**
 * @brief Take the next message from any sender's queue.
 *
 * `queue_sem_` counts the messages buffered across all per-sender queues.
 * After a successful wait, the queues are polled without blocking in
 * round-robin order. Polling resumes at `mq_iter_`, where the previous call
 * stopped, and wraps around until some queue yields a message.
 *
 * @param msg Output message.
 * @param send_id Output: ID of the sender whose queue produced `msg`.
 * @param timeout Semaphore wait timeout in milliseconds.
 * @return QUEUE_EMPTY on timeout, otherwise the status from Remove().
 */
STATUS SocketReceiver::Recv(Message* msg, int* send_id, int timeout) {
  if (!queue_sem_.TimedWait(timeout)) {
    return QUEUE_EMPTY;
  }
  while (true) {
    if (mq_iter_ == msg_queue_.end()) {
      mq_iter_ = msg_queue_.begin();  // wrap around
      continue;
    }
    auto current = mq_iter_++;
    const STATUS code = current->second->Remove(msg, false);
    if (code != QUEUE_EMPTY) {
      *send_id = current->first;
      return code;
    }
  }
}

335
/**
 * @brief Take the next message from the queue of sender `send_id`.
 *
 * Consumes one count from `queue_sem_` (keeping it in step with Recv()),
 * then blocks on that sender's queue.
 *
 * @param msg Output message.
 * @param send_id Sender whose queue to read; must be a known sender.
 * @param timeout Semaphore wait timeout in milliseconds.
 * @return QUEUE_EMPTY on timeout, otherwise the status from Remove().
 */
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id, int timeout) {
  // Look the queue up before touching the semaphore. operator[] would insert
  // a null queue for an unknown ID, and the insertion could rehash
  // msg_queue_, invalidating mq_iter_.
  auto it = msg_queue_.find(send_id);
  CHECK(it != msg_queue_.end()) << "Unknown sender id: " << send_id;
  if (!queue_sem_.TimedWait(timeout)) {
    return QUEUE_EMPTY;
  }
  // Get message from specified message queue
  return it->second->Remove(msg);
}

void SocketReceiver::Finalize() {
  // Send a signal to tell the message queue to finish its job
  for (auto& mq : msg_queue_) {
    // wait until queue is empty
    while (mq.second->Empty() == false) {
349
      std::this_thread::sleep_for(std::chrono::seconds(1));
350
    }
351
    mq.second->SignalFinished(mq.first);
352
353
354
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
355
    thread->join();
356
357
  }
  // Clear all sockets
358
359
360
361
  for (auto& group_sockets : sockets_) {
    for (auto& socket : group_sockets) {
      socket.second->Close();
    }
362
  }
Chao Ma's avatar
Chao Ma committed
363
364
  server_socket_->Close();
  delete server_socket_;
365
366
}

367
368
369
370
371
372
/**
 * @brief Read the 8-byte size header of the next frame from `socket`.
 *
 * @return -1 if the first read attempt gets nothing (Receive() == -1).
 *         Once part of the header has arrived, reading is retried until all
 *         8 bytes are in, and the decoded size is returned.
 */
int64_t RecvDataSize(TCPSocket* socket) {
  int64_t header = 0;
  char* const dst = reinterpret_cast<char*>(&header);
  int64_t got = 0;
  while (static_cast<size_t>(got) < sizeof(int64_t)) {
    const int64_t n = socket->Receive(
        dst + got, static_cast<int64_t>(sizeof(int64_t)) - got);
    if (n != -1) {
      got += n;
    } else if (got == 0) {
      return -1;
    }
    // Otherwise the header is partially read: keep polling until complete.
  }
  return header;
}

386
387
388
/**
 * @brief Continue reading a frame payload into `buffer`.
 *
 * Reads until `*received_bytes` reaches `data_size` or Receive() returns -1,
 * in which case it returns early. `*received_bytes` records progress, so the
 * next call resumes where this one stopped.
 */
void RecvData(
    TCPSocket* socket, char* buffer, const int64_t& data_size,
    int64_t* received_bytes) {
  int64_t& done = *received_bytes;
  while (done < data_size) {
    const int64_t n = socket->Receive(buffer + done, data_size - done);
    if (n == -1) {
      return;  // nothing more to read right now; resume on the next wake-up
    }
    done += n;
  }
}

void SocketReceiver::RecvLoop(
401
402
403
404
405
406
407
    std::unordered_map<
        int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>
        sockets,
    std::unordered_map<
        int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
        queues,
    runtime::Semaphore* queue_sem) {
408
409
410
  std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;
  SocketPool socket_pool;
  for (auto& socket : sockets) {
411
    auto& sender_id = socket.first;
412
413
414
415
416
417
418
419
420
421
422
423
424
    socket_pool.AddSocket(socket.second, sender_id);
    recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());
  }

  // Main loop to receive messages
  for (;;) {
    int sender_id;
    // Get active socket using epoll
    std::shared_ptr<TCPSocket> socket = socket_pool.GetActiveSocket(&sender_id);
    if (queues[sender_id]->EmptyAndNoMoreAdd()) {
      // This sender has already stopped
      if (socket_pool.RemoveSocket(socket) == 0) {
        return;
425
      }
426
427
428
429
430
      continue;
    }

    // Nonblocking socket might be interrupted at any point. So we need to
    // store the partially received data
431
432
433
    std::unique_ptr<RecvContext>& ctx = recv_contexts[sender_id];
    int64_t& data_size = ctx->data_size;
    int64_t& received_bytes = ctx->received_bytes;
434
435
436
437
438
439
440
441
    char*& buffer = ctx->buffer;

    if (data_size == -1) {
      // This is a new message, so receive the data size first
      data_size = RecvDataSize(socket.get());
      if (data_size > 0) {
        try {
          buffer = new char[data_size];
442
        } catch (const std::bad_alloc&) {
443
444
445
446
447
448
449
450
451
          LOG(FATAL) << "Cannot allocate enough memory for message, "
                     << "(message size: " << data_size << ")";
        }
        received_bytes = 0;
      } else if (data_size == 0) {
        // Received stop signal
        if (socket_pool.RemoveSocket(socket) == 0) {
          return;
        }
452
      }
453
454
455
456
457
    }

    RecvData(socket.get(), buffer, data_size, &received_bytes);
    if (received_bytes >= data_size) {
      // Full data received, create Message and push to queue
458
459
460
461
      Message msg;
      msg.data = buffer;
      msg.size = data_size;
      msg.deallocator = DefaultMessageDeleter;
462
463
464
465
466
467
468
      queues[sender_id]->Add(msg);

      // Reset recv context
      data_size = -1;

      // Signal queue semaphore
      queue_sem->Post();
469
470
471
472
473
474
    }
  }
}

}  // namespace network
}  // namespace dgl