socket_communicator.cc 13.9 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2019 by Contributors
 * \file communicator.cc
 * \brief SocketCommunicator for DGL distributed training.
 */
#include <dmlc/logging.h>

8
9
10
#include <string.h>
#include <stdlib.h>
#include <time.h>
11
#include <memory>
12

13
14
#include "socket_communicator.h"
#include "../../c_api_common.h"
15
#include "socket_pool.h"
16
17
18
19
20
21
22
23
24
25
26

#ifdef _WIN32
#include <windows.h>
#else   // !_WIN32
#include <unistd.h>
#endif  // _WIN32

namespace dgl {
namespace network {


27
28
29
/////////////////////////////////////// SocketSender ///////////////////////////////////////////


30
/*!
 * \brief Record the address of a receiver for a later ConnectReceiverFinalize().
 *
 * No connection is opened here; the parsed address is only stored in
 * receiver_addrs_[recv_id].
 *
 * \param addr receiver address in the form "tcp://<ip>:<port>".
 * \param recv_id non-negative receiver ID.
 * \return true; malformed input is reported through LOG(FATAL).
 */
bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
  if (recv_id < 0) {
    LOG(FATAL) << "recv_id cannot be a negative number.";
  }
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format. The size is tested first so that an empty split
  // result is never indexed.
  if (substring.size() != 2 || substring[0] != "tcp:") {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'tcp://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'tcp://127.0.0.1:50051'. ";
  }
  IPAddr address;
  address.ip = ip_and_port[0];
  address.port = std::stoi(ip_and_port[1]);
  receiver_addrs_[recv_id] = address;
  return true;
}

58
bool SocketSender::ConnectReceiverFinalize(const int max_try_times) {
59
  // Create N sockets for Receiver
60
61
62
63
64
  int receiver_count = static_cast<int>(receiver_addrs_.size());
  if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
    max_thread_count_ = receiver_count;
  }
  sockets_.resize(max_thread_count_);
65
  for (const auto& r : receiver_addrs_) {
66
67
68
69
    int receiver_id = r.first;
    int thread_id = receiver_id % max_thread_count_;
    sockets_[thread_id][receiver_id] = std::make_shared<TCPSocket>();
    TCPSocket* client_socket = sockets_[thread_id][receiver_id].get();
70
71
    bool bo = false;
    int try_count = 0;
72
73
    const char* ip = r.second.ip.c_str();
    int port = r.second.port;
74
    while (bo == false && try_count < max_try_times) {
75
      if (client_socket->Connect(ip, port)) {
76
77
        bo = true;
      } else {
78
        if (try_count % 200 == 0 && try_count != 0) {
79
80
          // every 600 seconds show this message
          LOG(INFO) << "Trying to connect receiver: " << ip << ":" << port;
81
        }
82
        try_count++;
83
        std::this_thread::sleep_for(std::chrono::seconds(3));
84
85
86
87
88
      }
    }
    if (bo == false) {
      return bo;
    }
89
90
91
92
  }

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
    msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));
93
    // Create a new thread for this socket connection
94
    threads_.push_back(std::make_shared<std::thread>(
95
      SendLoop,
96
97
      sockets_[thread_id],
      msg_queue_[thread_id]));
98
  }
99

100
101
102
  return true;
}

103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/*!
 * \brief Serialize an RPCMessage and enqueue it for receiver recv_id.
 *
 * The message is sent in two parts:
 *  1. A metadata message holding the serialized RPCMessage, followed by a
 *     trailing int32_t with the number of zero-copy NDArray buffers.
 *  2. One message per NDArray buffer, sent in buffer_list() order.
 *
 * Each Message's deallocator lambda captures the owning object (the
 * std::string blob or the NDArray). That keeps the memory referenced by
 * Message::data alive for as long as the Message (and its copies) exist.
 *
 * \param msg RPC message to send.
 * \param recv_id destination receiver ID.
 */
void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
  std::shared_ptr<std::string> zerocopy_blob(new std::string());
  StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true);
  zc_write_strm.Write(msg);
  const int32_t nonempty_ndarray_count =
      static_cast<int32_t>(zc_write_strm.buffer_list().size());
  zerocopy_blob->append(reinterpret_cast<const char*>(&nonempty_ndarray_count),
                        sizeof(int32_t));
  Message rpc_meta_msg;
  rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data());
  rpc_meta_msg.size = zerocopy_blob->size();
  rpc_meta_msg.deallocator = [zerocopy_blob](Message*) {};
  CHECK_EQ(Send(rpc_meta_msg, recv_id), ADD_SUCCESS);
  // Send the real NDArray data. Bind by reference so each buffer descriptor
  // (and the NDArray reference it holds) is not copied per iteration.
  for (const auto& buf : zc_write_strm.buffer_list()) {
    if (buf.size == 0) {
      LOG(FATAL) << "Cannot send a empty NDArray.";
    }
    Message ndarray_data_msg;
    ndarray_data_msg.data = reinterpret_cast<char*>(buf.data);
    ndarray_data_msg.size = buf.size;
    NDArray tensor = buf.tensor;
    ndarray_data_msg.deallocator = [tensor](Message*) {};
    CHECK_EQ(Send(ndarray_data_msg, recv_id), ADD_SUCCESS);
  }
}

131
132
133
134
/*!
 * \brief Enqueue a raw message for receiver recv_id.
 *
 * The message is added to the queue of the send thread that owns recv_id
 * (index recv_id % max_thread_count_). That thread writes it to the socket
 * later.
 *
 * \param msg message to send; data must be non-null and size > 0.
 * \param recv_id receiver ID previously registered via ConnectReceiver().
 * \return the status returned by MessageQueue::Add().
 */
STATUS SocketSender::Send(Message msg, int recv_id) {
  CHECK_NOTNULL(msg.data);
  CHECK_GT(msg.size, 0);
  CHECK_GE(recv_id, 0);
  // Queues only exist after ConnectReceiverFinalize() succeeded. Without
  // this check, the modulo below may divide by zero or index an empty vector.
  CHECK(!msg_queue_.empty())
      << "ConnectReceiverFinalize() must succeed before Send() is called.";
  // An unregistered receiver has no socket in any send thread.
  CHECK(receiver_addrs_.count(recv_id) > 0)
      << "Unknown receiver ID: " << recv_id;
  msg.receiver_id = recv_id;
  // Add data message to message queue
  STATUS code = msg_queue_[recv_id % max_thread_count_]->Add(msg);
  return code;
}

void SocketSender::Finalize() {
142
  // Send a signal to tell the msg_queue to finish its job
143
  for (int i = 0; i < max_thread_count_; ++i) {
144
    // wait until queue is empty
145
146
    auto& mq = msg_queue_[i];
    while (mq->Empty() == false) {
147
      std::this_thread::sleep_for(std::chrono::seconds(1));
148
    }
149
150
151
    // All queues have only one producer, which is main thread, so
    // the producerID argument here should be zero.
    mq->SignalFinished(0);
152
153
154
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
155
    thread->join();
156
157
  }
  // Clear all sockets
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
  for (auto& group_sockets_ : sockets_) {
    for (auto &socket : group_sockets_) {
      socket.second->Close();
    }
  }
}

/*!
 * \brief Write one framed message to a socket.
 *
 * The frame is an int64_t length header (msg.size) followed by msg.size
 * payload bytes. Short writes are retried until everything is sent, and a
 * Send() result of -1 fails a CHECK. A message with size 0 writes only the
 * header; SendLoop uses that as the end-signal for the receiver. Afterwards
 * msg.deallocator, if set, is invoked on this local copy of the message.
 *
 * \param msg message to send (taken by value).
 * \param socket connected socket to write to.
 */
void SendCore(Message msg, TCPSocket* socket) {
  char* const header = reinterpret_cast<char*>(&msg.size);
  for (int64_t written = 0; static_cast<size_t>(written) < sizeof(int64_t);) {
    const int64_t remaining = sizeof(int64_t) - written;
    const int64_t n = socket->Send(header + written, remaining);
    CHECK_NE(n, -1);
    written += n;
  }

  for (int64_t offset = 0; offset < msg.size;) {
    const int64_t remaining = msg.size - offset;
    const int64_t n = socket->Send(msg.data + offset, remaining);
    CHECK_NE(n, -1);
    offset += n;
  }

  // Release whatever keeps the payload alive (see the deallocator set by Send).
  if (msg.deallocator) {
    msg.deallocator(&msg);
  }
}

191
192
193
194
void SocketSender::SendLoop(
  std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
  std::shared_ptr<MessageQueue> queue) {
  for (;;) {
195
196
197
198
    Message msg;
    STATUS code = queue->Remove(&msg);
    if (code == QUEUE_CLOSE) {
      msg.size = 0;  // send an end-signal to receiver
199
200
201
202
      for (auto& socket : sockets) {
        SendCore(msg, socket.second.get());
      }
      break;
203
    }
204
    SendCore(msg, sockets[msg.receiver_id].get());
205
  }
206
207
}

208
/////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
209
bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
210
  CHECK_GT(num_sender, 0);
211
  CHECK_EQ(blocking, true);
212
213
214
215
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format
216
  if (substring[0] != "tcp:" || substring.size() != 2) {
217
218
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
219
               << "e.g, 'tcp://127.0.0.1:50051'. ";
220
221
222
223
224
225
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
226
               << "e.g, 'tcp://127.0.0.1:50051'. ";
227
228
229
230
  }
  std::string ip = ip_and_port[0];
  int port = stoi(ip_and_port[1]);
  // Initialize message queue for each connection
231
  num_sender_ = num_sender;
232
233
234
#ifdef USE_EPOLL
  if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {
      max_thread_count_ = num_sender_;
235
  }
236
237
238
#else
  max_thread_count_ = num_sender_;
#endif
239
240
  // Initialize socket and socket-thread
  server_socket_ = new TCPSocket();
241
  // Bind socket
242
  if (server_socket_->Bind(ip.c_str(), port) == false) {
243
    LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
244
  }
245

246
  // Listen
247
  if (server_socket_->Listen(kMaxConnection) == false) {
248
    LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
249
250
251
252
  }
  // Accept all sender sockets
  std::string accept_ip;
  int accept_port;
253
  sockets_.resize(max_thread_count_);
254
  for (int i = 0; i < num_sender_; ++i) {
255
256
257
258
259
    int thread_id = i % max_thread_count_;
    auto socket = std::make_shared<TCPSocket>();
    sockets_[thread_id][i] = socket;
    msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
    if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) == false) {
260
      LOG(WARNING) << "Error on accept socket.";
261
262
      return false;
    }
263
264
265
266
  }
  mq_iter_ = msg_queue_.begin();

  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
267
    // create new thread for each socket
268
    threads_.push_back(std::make_shared<std::thread>(
269
      RecvLoop,
270
271
272
      sockets_[thread_id],
      msg_queue_,
      &queue_sem_));
273
274
275
276
277
  }

  return true;
}

278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
/*!
 * \brief Receive one RPCMessage from any sender.
 *
 * Mirrors SocketSender::Send(const rpc::RPCMessage&, int). First a metadata
 * message arrives, ending with an int32_t count of NDArray buffers. Then
 * that many data messages are fetched from the same sender, and the RPC
 * message is deserialized over them.
 *
 * \param msg output RPC message.
 */
void SocketReceiver::Recv(rpc::RPCMessage* msg) {
  Message rpc_meta_msg;
  int send_id;
  CHECK_EQ(Recv(&rpc_meta_msg, &send_id), REMOVE_SUCCESS);
  // The metadata message must at least hold the trailing buffer count.
  CHECK_GE(rpc_meta_msg.size, static_cast<int64_t>(sizeof(int32_t)))
      << "Malformed RPC metadata message (size " << rpc_meta_msg.size << ")";
  const auto meta_size = rpc_meta_msg.size - sizeof(int32_t);
  // memcpy instead of dereferencing a reinterpret_cast pointer: the count
  // sits at an arbitrary byte offset and may be misaligned.
  int32_t nonempty_ndarray_count = 0;
  memcpy(&nonempty_ndarray_count, rpc_meta_msg.data + meta_size,
         sizeof(int32_t));
  CHECK_GE(nonempty_ndarray_count, 0);
  // Recv real ndarray data, which always comes from the same sender.
  std::vector<void*> buffer_list(nonempty_ndarray_count);
  for (int32_t i = 0; i < nonempty_ndarray_count; ++i) {
    Message ndarray_data_msg;
    CHECK_EQ(RecvFrom(&ndarray_data_msg, send_id), REMOVE_SUCCESS);
    buffer_list[i] = ndarray_data_msg.data;
  }
  StreamWithBuffer zc_read_strm(rpc_meta_msg.data, meta_size, buffer_list);
  zc_read_strm.Read(msg);
  rpc_meta_msg.deallocator(&rpc_meta_msg);
}

298
/*!
 * \brief Fetch the next available message from any sender.
 *
 * First blocks on queue_sem_, which counts queued messages across all
 * queues. Then polls the per-sender queues non-blockingly in round-robin
 * order. The scan resumes just after the queue that served the previous
 * call; the position is kept in mq_iter_.
 *
 * \param msg output message.
 * \param send_id output ID of the sender whose queue produced msg.
 * \return the first status other than QUEUE_EMPTY returned by a queue.
 */
STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
  queue_sem_.Wait();
  while (true) {
    if (mq_iter_ == msg_queue_.end()) {
      // Reached the last queue: wrap around to the first one.
      mq_iter_ = msg_queue_.begin();
      continue;
    }
    auto current = mq_iter_++;
    const STATUS status = current->second->Remove(msg, false);
    if (status != QUEUE_EMPTY) {
      *send_id = current->first;
      return status;
    }
  }
}

/*!
 * \brief Receive the next message from one specific sender.
 *
 * Consumes one count from queue_sem_ (like Recv()), then blocks on that
 * sender's queue.
 *
 * \param msg output message.
 * \param send_id sender ID whose queue is read.
 * \return the status returned by MessageQueue::Remove().
 */
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
  // Look the queue up before blocking. operator[] would insert (and then
  // dereference) a null queue for an unknown sender ID.
  auto it = msg_queue_.find(send_id);
  CHECK(it != msg_queue_.end()) << "Unknown sender ID: " << send_id;
  queue_sem_.Wait();
  STATUS code = it->second->Remove(msg);
  return code;
}

void SocketReceiver::Finalize() {
  // Send a signal to tell the message queue to finish its job
  for (auto& mq : msg_queue_) {
    // wait until queue is empty
    while (mq.second->Empty() == false) {
331
        std::this_thread::sleep_for(std::chrono::seconds(1));
332
    }
333
    mq.second->SignalFinished(mq.first);
334
335
336
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
337
    thread->join();
338
339
  }
  // Clear all sockets
340
341
342
343
  for (auto& group_sockets : sockets_) {
    for (auto& socket : group_sockets) {
      socket.second->Close();
    }
344
  }
Chao Ma's avatar
Chao Ma committed
345
346
  server_socket_->Close();
  delete server_socket_;
347
348
}

349
350
351
352
353
354
355
356
357
358
359
360
361
362
/*!
 * \brief Read the int64_t length header of the next message from a socket.
 *
 * If Receive() reports -1 before any header byte has been read, gives up
 * and returns -1. Once part of the header has arrived, keeps retrying until
 * all sizeof(int64_t) bytes are read, so a header is never left half-read.
 *
 * \param socket socket to read from.
 * \return the decoded message size, or -1 if nothing could be read.
 */
int64_t RecvDataSize(TCPSocket* socket) {
  int64_t header_value = 0;
  char* const dst = reinterpret_cast<char*>(&header_value);
  int64_t filled = 0;
  while (static_cast<size_t>(filled) < sizeof(int64_t)) {
    const int64_t want = sizeof(int64_t) - filled;
    const int64_t got = socket->Receive(dst + filled, want);
    if (got != -1) {
      filled += got;
      continue;
    }
    if (filled <= 0) {
      return -1;
    }
    // A partial header is already buffered: retry until the rest arrives.
  }
  return header_value;
}

/*!
 * \brief Continue receiving a message payload into buffer.
 *
 * Reads until *received_bytes reaches data_size or Receive() reports -1,
 * whichever comes first. *received_bytes is updated in place, so a partially
 * received message can be resumed by a later call.
 *
 * \param socket socket to read from.
 * \param buffer destination buffer of at least data_size bytes.
 * \param data_size total payload size of the message.
 * \param received_bytes in/out count of payload bytes already in buffer.
 */
void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size,
  int64_t *received_bytes) {
  int64_t& offset = *received_bytes;
  while (offset < data_size) {
    const int64_t chunk = socket->Receive(buffer + offset, data_size - offset);
    if (chunk == -1) {
      return;  // nothing more to read right now; resume on a later call
    }
    offset += chunk;
  }
}

void SocketReceiver::RecvLoop(
  std::unordered_map<int /* Sender (virtual) ID */,
    std::shared_ptr<TCPSocket>> sockets,
  std::unordered_map<int /* Sender (virtual) ID */,
    std::shared_ptr<MessageQueue>> queues,
  runtime::Semaphore *queue_sem) {
  std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;
  SocketPool socket_pool;
  for (auto& socket : sockets) {
    auto &sender_id = socket.first;
    socket_pool.AddSocket(socket.second, sender_id);
    recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());
  }

  // Main loop to receive messages
  for (;;) {
    int sender_id;
    // Get active socket using epoll
    std::shared_ptr<TCPSocket> socket = socket_pool.GetActiveSocket(&sender_id);
    if (queues[sender_id]->EmptyAndNoMoreAdd()) {
      // This sender has already stopped
      if (socket_pool.RemoveSocket(socket) == 0) {
        return;
405
      }
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
      continue;
    }

    // Nonblocking socket might be interrupted at any point. So we need to
    // store the partially received data
    std::unique_ptr<RecvContext> &ctx = recv_contexts[sender_id];
    int64_t &data_size = ctx->data_size;
    int64_t &received_bytes = ctx->received_bytes;
    char*& buffer = ctx->buffer;

    if (data_size == -1) {
      // This is a new message, so receive the data size first
      data_size = RecvDataSize(socket.get());
      if (data_size > 0) {
        try {
          buffer = new char[data_size];
        } catch(const std::bad_alloc&) {
          LOG(FATAL) << "Cannot allocate enough memory for message, "
                     << "(message size: " << data_size << ")";
        }
        received_bytes = 0;
      } else if (data_size == 0) {
        // Received stop signal
        if (socket_pool.RemoveSocket(socket) == 0) {
          return;
        }
432
      }
433
434
435
436
437
    }

    RecvData(socket.get(), buffer, data_size, &received_bytes);
    if (received_bytes >= data_size) {
      // Full data received, create Message and push to queue
438
439
440
441
      Message msg;
      msg.data = buffer;
      msg.size = data_size;
      msg.deallocator = DefaultMessageDeleter;
442
443
444
445
446
447
448
      queues[sender_id]->Add(msg);

      // Reset recv context
      data_size = -1;

      // Signal queue semaphore
      queue_sem->Post();
449
450
451
452
453
454
    }
  }
}

}  // namespace network
}  // namespace dgl