/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.cc
 * \brief SocketCommunicator for DGL distributed training.
 */
#include <dmlc/logging.h>
#include <string.h>
#include <stdlib.h>
#include <time.h>

#include <chrono>
#include <exception>
#include <memory>
#include <new>
#include <string>
#include <thread>
#include <vector>

#include "socket_communicator.h"
#include "../../c_api_common.h"

#ifdef _WIN32
#include <windows.h>
#else   // !_WIN32
#include <unistd.h>
#endif  // _WIN32

namespace dgl {
namespace network {


/////////////////////////////////////// SocketSender ///////////////////////////////////////////


/*!
 * \brief Register a receiver address under `recv_id` and create the
 *        outgoing message queue for it.
 * \param addr receiver address of the form "socket://<ip>:<port>".
 * \param recv_id non-negative receiver ID; an existing entry with the same
 *        ID is overwritten.
 *
 * Malformed input (negative ID, bad address, bad port) is reported through
 * LOG(FATAL).
 */
void SocketSender::AddReceiver(const char* addr, int recv_id) {
  CHECK_NOTNULL(addr);
  if (recv_id < 0) {
    LOG(FATAL) << "recv_id cannot be a negative number.";
  }
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format. The size is tested first so that an address
  // without "//" never reads substring[0] out of bounds.
  if (substring.size() != 2 || substring[0] != "socket:") {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  // Validate the port explicitly. A bare std::stoi would throw an
  // unexplained std::invalid_argument / std::out_of_range on bad input and
  // would silently accept trailing characters such as "50051abc".
  int port = -1;
  try {
    size_t parsed = 0;
    port = std::stoi(ip_and_port[1], &parsed);
    if (parsed != ip_and_port[1].size()) {
      port = -1;  // trailing non-digit characters
    }
  } catch (const std::exception&) {
    port = -1;  // not a number, or outside the int range
  }
  if (port < 0 || port > 65535) {
    LOG(FATAL) << "Incorrect port number:" << ip_and_port[1]
               << " in address:" << addr;
  }
  IPAddr address;
  address.ip = ip_and_port[0];
  address.port = port;
  receiver_addrs_[recv_id] = address;
  msg_queue_[recv_id] = std::make_shared<MessageQueue>(queue_size_);
}

bool SocketSender::Connect() {
  // Create N sockets for Receiver
58
  for (const auto& r : receiver_addrs_) {
59
    int ID = r.first;
60
61
    sockets_[ID] = std::make_shared<TCPSocket>();
    TCPSocket* client_socket = sockets_[ID].get();
62
63
    bool bo = false;
    int try_count = 0;
64
65
    const char* ip = r.second.ip.c_str();
    int port = r.second.port;
66
    while (bo == false && try_count < kMaxTryCount) {
67
      if (client_socket->Connect(ip, port)) {
68
69
70
71
72
73
        LOG(INFO) << "Connected to Receiver: " << ip << ":" << port;
        bo = true;
      } else {
        LOG(ERROR) << "Cannot connect to Receiver: " << ip << ":" << port
                   << ", try again ...";
        try_count++;
74
#ifdef _WIN32
75
        Sleep(1);
76
#else   // !_WIN32
77
        sleep(1);
78
#endif  // _WIN32
79
80
81
82
83
      }
    }
    if (bo == false) {
      return bo;
    }
84
85
86
87
88
    // Create a new thread for this socket connection
    threads_[ID] = std::make_shared<std::thread>(
      SendLoop,
      client_socket,
      msg_queue_[ID].get());
89
90
91
92
  }
  return true;
}

/*!
 * \brief Enqueue `msg` for asynchronous delivery to receiver `recv_id`.
 * \param msg message with non-null data and positive size.
 * \param recv_id ID previously registered through AddReceiver().
 * \return the status reported by the receiver's MessageQueue::Add.
 */
STATUS SocketSender::Send(Message msg, int recv_id) {
  CHECK_NOTNULL(msg.data);
  CHECK_GT(msg.size, 0);
  CHECK_GE(recv_id, 0);
  // Look the queue up with find(). operator[] would, for an unregistered
  // recv_id, insert a null shared_ptr into msg_queue_ and then dereference it.
  auto it = msg_queue_.find(recv_id);
  CHECK(it != msg_queue_.end())
      << "Unknown recv_id: " << recv_id << ", call AddReceiver() first.";
  // Add data message to message queue
  return it->second->Add(msg);
}

void SocketSender::Finalize() {
103
104
105
106
107
108
109
110
111
  // Send a signal to tell the msg_queue to finish its job
  for (auto& mq : msg_queue_) {
    // wait until queue is empty
    while (mq.second->Empty() == false) {
#ifdef _WIN32
        // just loop
#else   // !_WIN32
        usleep(1000);
#endif  // _WIN32
112
    }
113
114
115
116
117
118
119
120
121
122
    int ID = mq.first;
    mq.second->SignalFinished(ID);
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
    thread.second->join();
  }
  // Clear all sockets
  for (auto& socket : sockets_) {
    socket.second->Close();
123
  }
124
125
}

/*!
 * \brief Worker body for the thread Connect() starts per receiver.
 *
 * Repeatedly removes a message from `queue` and writes it to `socket` as an
 * int64_t length header followed by `size` payload bytes. After writing, the
 * message's deallocator (if any) is invoked. When Remove() reports
 * QUEUE_CLOSE, a zero-length header is written as the end-signal and the
 * loop exits.
 */
void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) {
  CHECK_NOTNULL(socket);
  CHECK_NOTNULL(queue);
  // Write exactly `total` bytes starting at `base`, looping over partial
  // sends; a -1 from Send() aborts through CHECK.
  auto write_all = [socket](char* base, int64_t total) {
    for (int64_t done = 0; done < total;) {
      int64_t tmp = socket->Send(base + done, total - done);
      CHECK_NE(tmp, -1);
      done += tmp;
    }
  };
  for (;;) {
    Message msg;
    const bool closing = (queue->Remove(&msg) == QUEUE_CLOSE);
    if (closing) {
      msg.size = 0;  // a zero-size header is the receiver's end-signal
    }
    // Header first, then the payload (skipped when size is 0).
    write_all(reinterpret_cast<char*>(&msg.size),
              static_cast<int64_t>(sizeof(int64_t)));
    write_all(msg.data, msg.size);
    if (msg.deallocator != nullptr) {
      msg.deallocator(&msg);
    }
    if (closing) {
      break;
    }
  }
}

/////////////////////////////////////// SocketReceiver ///////////////////////////////////////////

bool SocketReceiver::Wait(const char* addr, int num_sender) {
  CHECK_NOTNULL(addr);
  CHECK_GT(num_sender, 0);
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format
  if (substring[0] != "socket:" || substring.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  std::string ip = ip_and_port[0];
  int port = stoi(ip_and_port[1]);
  // Initialize message queue for each connection
187
  num_sender_ = num_sender;
188
189
190
191
192
193
  for (int i = 0; i < num_sender_; ++i) {
    msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
  }
  // Initialize socket and socket-thread
  server_socket_ = new TCPSocket();
  server_socket_->SetTimeout(kTimeOut * 60 * 1000);  // millsec
194
  // Bind socket
195
  if (server_socket_->Bind(ip.c_str(), port) == false) {
196
    LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
197
198
199
  }
  LOG(INFO) << "Bind to " << ip << ":" << port;
  // Listen
200
  if (server_socket_->Listen(kMaxConnection) == false) {
201
    LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
202
203
204
205
206
  }
  LOG(INFO) << "Listen on " << ip << ":" << port << ", wait sender connect ...";
  // Accept all sender sockets
  std::string accept_ip;
  int accept_port;
207
208
209
210
  for (int i = 0; i < num_sender_; ++i) {
    sockets_[i] = std::make_shared<TCPSocket>();
    if (server_socket_->Accept(sockets_[i].get(), &accept_ip, &accept_port) == false) {
      LOG(WARNING) << "Error on accept socket.";
211
212
      return false;
    }
213
    // create new thread for each socket
214
215
216
217
    threads_[i] = std::make_shared<std::thread>(
      RecvLoop,
      sockets_[i].get(),
      msg_queue_[i].get());
218
219
220
221
222
223
    LOG(INFO) << "Accept new sender: " << accept_ip << ":" << accept_port;
  }

  return true;
}

/*!
 * \brief Receive the next available message from any sender.
 * \param msg output message.
 * \param send_id output: ID of the sender whose queue produced the result.
 * \return the first status other than QUEUE_EMPTY reported by a
 *         non-blocking Remove() on any sender queue.
 *
 * Queues are polled in msg_queue_ iteration order; this call does not
 * return until some queue yields a non-QUEUE_EMPTY status.
 */
STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
  CHECK_NOTNULL(msg);
  CHECK_NOTNULL(send_id);
  // With no queues (e.g. Recv() called before Wait()) the polling loop
  // below could never return.
  CHECK(!msg_queue_.empty()) << "No sender queues; call Wait() before Recv().";
  // loop until get a message
  for (;;) {
    for (auto& mq : msg_queue_) {
      // We use non-block remove here, directly on the iterated queue
      // (no second map lookup).
      STATUS code = mq.second->Remove(msg, false);
      if (code != QUEUE_EMPTY) {
        *send_id = mq.first;
        return code;
      }
    }
    // Every queue was empty on this pass. Yield so the RecvLoop threads
    // filling the queues are not starved by this polling loop.
    std::this_thread::yield();
  }
}

/*!
 * \brief Receive the next message from the specific sender `send_id`.
 * \param msg output message.
 * \param send_id ID of a sender accepted by Wait().
 * \return the status reported by that queue's Remove() (called with its
 *         default arguments; Recv() passes `false` for the non-blocking
 *         variant, so this is presumably the blocking one).
 */
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
  CHECK_NOTNULL(msg);
  // find() instead of operator[]: an unknown send_id would otherwise insert
  // a null queue into msg_queue_ and then dereference it.
  auto it = msg_queue_.find(send_id);
  CHECK(it != msg_queue_.end()) << "Unknown send_id: " << send_id;
  // Get message from specified message queue
  return it->second->Remove(msg);
}

/*!
 * \brief Shut down the receiver: wait for all queues to drain, signal them
 *        finished, join the RecvLoop threads and close the sockets.
 *
 * The queues are drained by the caller's Recv()/RecvFrom() calls, so this
 * blocks (polling every 1 ms) until every received message has been taken.
 */
void SocketReceiver::Finalize() {
  // Send a signal to tell the message queue to finish its job
  for (auto& mq : msg_queue_) {
    // Wait until the queue is empty. Sleep on every platform; the Windows
    // branch used to busy-spin at 100% CPU here.
    while (!mq.second->Empty()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    mq.second->SignalFinished(mq.first);
  }
  // Block main thread until all socket-threads finish their jobs. Skip
  // threads that were already joined (e.g. a repeated Finalize call), since
  // join() on a non-joinable thread throws std::system_error.
  for (auto& thread : threads_) {
    if (thread.second->joinable()) {
      thread.second->join();
    }
  }
  // Clear all sockets
  for (auto& socket : sockets_) {
    socket.second->Close();
  }
}

/*!
 * \brief Worker body for the thread Wait() starts per accepted sender.
 *
 * Reads frames of the form [int64_t size][size bytes] from `socket` and adds
 * each payload to `queue` as a Message owning a new[]-allocated buffer
 * (released via DefaultMessageDeleter). Exits when:
 *  - the queue reports EmptyAndNoMoreAdd(),
 *  - a zero-size frame (the sender's end-signal) arrives, or
 *  - the peer closes the connection cleanly between frames.
 * A negative size, a -1 from Receive(), or a close in the middle of a frame
 * is fatal.
 */
void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) {
  CHECK_NOTNULL(socket);
  CHECK_NOTNULL(queue);
  for (;;) {
    // If main thread had finished its job
    if (queue->EmptyAndNoMoreAdd()) {
      return;  // exit loop thread
    }
    // First recv the size
    int64_t received_bytes = 0;
    int64_t data_size = 0;
    while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) {
      int64_t max_len = sizeof(int64_t) - received_bytes;
      int64_t tmp = socket->Receive(
        reinterpret_cast<char*>(&data_size)+received_bytes,
        max_len);
      CHECK_NE(tmp, -1);
      if (tmp == 0) {
        // Assuming TCPSocket::Receive mirrors recv(2), 0 means the peer shut
        // the connection down. Previously this was added as 0 bytes of
        // progress, so the thread spun forever.
        if (received_bytes == 0) {
          LOG(WARNING) << "Sender closed the connection without an end-signal.";
          return;
        }
        LOG(FATAL) << "Connection closed while receiving message size ("
                   << received_bytes << " of " << sizeof(int64_t)
                   << " bytes received)";
      }
      received_bytes += tmp;
    }
    if (data_size < 0) {
      LOG(FATAL) << "Recv data error (data_size: " << data_size << ")";
    } else if (data_size == 0) {
      // This is an end-signal sent by client
      return;
    }
    // Own the buffer with RAII until it is handed over to the Message.
    std::unique_ptr<char[]> buffer;
    try {
      buffer.reset(new char[data_size]);
    } catch(const std::bad_alloc&) {
      LOG(FATAL) << "Cannot allocate enough memory for message, "
                 << "(message size: " << data_size << ")";
    }
    received_bytes = 0;
    while (received_bytes < data_size) {
      int64_t max_len = data_size - received_bytes;
      int64_t tmp = socket->Receive(buffer.get()+received_bytes, max_len);
      CHECK_NE(tmp, -1);
      if (tmp == 0) {
        LOG(FATAL) << "Connection closed while receiving message body ("
                   << received_bytes << " of " << data_size
                   << " bytes received)";
      }
      received_bytes += tmp;
    }
    Message msg;
    msg.data = buffer.release();  // ownership passes to the message
    msg.size = data_size;
    msg.deallocator = DefaultMessageDeleter;
    queue->Add(msg);
  }
}

}  // namespace network
}  // namespace dgl