"src/vscode:/vscode.git/clone" did not exist on "23904d54d0beeec4f112ae19cd00e73722f4c113"
socket_communicator.cc 8.96 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.cc
 * \brief SocketCommunicator for DGL distributed training.
 */
#include <dmlc/logging.h>

8
9
10
#include <string.h>
#include <stdlib.h>
#include <time.h>
11
#include <memory>
12

13
14
15
16
17
18
19
20
21
22
23
24
25
#include "socket_communicator.h"
#include "../../c_api_common.h"

#ifdef _WIN32
#include <windows.h>
#else   // !_WIN32
#include <unistd.h>
#endif  // _WIN32

namespace dgl {
namespace network {


26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/////////////////////////////////////// SocketSender ///////////////////////////////////////////


/*!
 * \brief Register a receiver address under the given receiver ID.
 *
 * Parses an address of the form 'socket://<ip>:<port>', stores the resulting
 * IP/port pair in receiver_addrs_ and creates the message queue used for
 * this receiver. A negative recv_id or a malformed address is reported
 * through LOG(FATAL).
 *
 * \param addr receiver address, e.g. 'socket://127.0.0.1:50051'
 * \param recv_id non-negative receiver ID
 */
void SocketSender::AddReceiver(const char* addr, int recv_id) {
  CHECK_NOTNULL(addr);
  if (recv_id < 0) {
    LOG(FATAL) << "recv_id cannot be a negative number.";
  }
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format. The size is tested first so that substring[0]
  // is never read when the split produced fewer pieces than expected.
  if (substring.size() != 2 || substring[0] != "socket:") {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  IPAddr address;
  address.ip = ip_and_port[0];
  address.port = std::stoi(ip_and_port[1]);
  receiver_addrs_[recv_id] = address;
  msg_queue_[recv_id] = std::make_shared<MessageQueue>(queue_size_);
}

57
58
bool SocketSender::Connect() {
  // Create N sockets for Receiver
59
  for (const auto& r : receiver_addrs_) {
60
    int ID = r.first;
61
62
    sockets_[ID] = std::make_shared<TCPSocket>();
    TCPSocket* client_socket = sockets_[ID].get();
63
64
    bool bo = false;
    int try_count = 0;
65
66
    const char* ip = r.second.ip.c_str();
    int port = r.second.port;
67
    while (bo == false && try_count < kMaxTryCount) {
68
      if (client_socket->Connect(ip, port)) {
69
70
        bo = true;
      } else {
71
72
73
        if (try_count % 10 == 0 && try_count != 0) {
          LOG(INFO) << "Try to connect to: " << ip << ":" << port;
        }
74
        try_count++;
75
#ifdef _WIN32
76
        Sleep(5);
77
#else   // !_WIN32
78
        sleep(5);
79
#endif  // _WIN32
80
81
82
83
84
      }
    }
    if (bo == false) {
      return bo;
    }
85
86
87
88
89
    // Create a new thread for this socket connection
    threads_[ID] = std::make_shared<std::thread>(
      SendLoop,
      client_socket,
      msg_queue_[ID].get());
90
91
92
93
  }
  return true;
}

94
95
96
97
98
99
100
/*!
 * \brief Queue a message for delivery to the given receiver.
 *
 * The message is only appended to the receiver's message queue here; the
 * status returned by the queue's Add() is handed back to the caller.
 *
 * \param msg message to send; msg.data must be non-null and msg.size > 0
 * \param recv_id non-negative receiver ID
 * \return status code returned by the message queue
 */
STATUS SocketSender::Send(Message msg, int recv_id) {
  // Validate the arguments before touching the queue
  CHECK_NOTNULL(msg.data);
  CHECK_GT(msg.size, 0);
  CHECK_GE(recv_id, 0);
  return msg_queue_[recv_id]->Add(msg);
}

void SocketSender::Finalize() {
104
105
106
107
108
109
110
111
112
  // Send a signal to tell the msg_queue to finish its job
  for (auto& mq : msg_queue_) {
    // wait until queue is empty
    while (mq.second->Empty() == false) {
#ifdef _WIN32
        // just loop
#else   // !_WIN32
        usleep(1000);
#endif  // _WIN32
113
    }
114
115
116
117
118
119
120
121
122
123
    int ID = mq.first;
    mq.second->SignalFinished(ID);
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
    thread.second->join();
  }
  // Clear all sockets
  for (auto& socket : sockets_) {
    socket.second->Close();
124
  }
125
126
}

127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/*!
 * \brief Thread body: take messages from the queue and write them to the socket.
 *
 * Every message is written as an int64_t size header followed by
 * msg.size bytes of payload. When the queue returns QUEUE_CLOSE, a header
 * with size zero is written as the end-signal for the receiver and the loop
 * terminates.
 *
 * \param socket connected socket to write to (must not be null)
 * \param queue message queue to read from (must not be null)
 */
void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) {
  CHECK_NOTNULL(socket);
  CHECK_NOTNULL(queue);
  bool finished = false;
  do {
    Message msg;
    const STATUS status = queue->Remove(&msg);
    if (status == QUEUE_CLOSE) {
      // A zero size is the end-signal for the receiver
      msg.size = 0;
      finished = true;
    }
    // Size header first; Send() may write fewer bytes than requested
    int64_t offset = 0;
    while (static_cast<size_t>(offset) < sizeof(int64_t)) {
      int64_t remaining = sizeof(int64_t) - offset;
      int64_t written = socket->Send(
        reinterpret_cast<char*>(&msg.size) + offset,
        remaining);
      CHECK_NE(written, -1);
      offset += written;
    }
    // Then the payload (skipped when msg.size is zero)
    for (offset = 0; offset < msg.size;) {
      int64_t remaining = msg.size - offset;
      int64_t written = socket->Send(msg.data + offset, remaining);
      CHECK_NE(written, -1);
      offset += written;
    }
    // Release the message if it carries a deallocator
    if (msg.deallocator != nullptr) {
      msg.deallocator(&msg);
    }
  } while (!finished);
}

164
/////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
165

166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
bool SocketReceiver::Wait(const char* addr, int num_sender) {
  CHECK_NOTNULL(addr);
  CHECK_GT(num_sender, 0);
  std::vector<std::string> substring;
  std::vector<std::string> ip_and_port;
  SplitStringUsing(addr, "//", &substring);
  // Check address format
  if (substring[0] != "socket:" || substring.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  // Get IP and port
  SplitStringUsing(substring[1], ":", &ip_and_port);
  if (ip_and_port.size() != 2) {
    LOG(FATAL) << "Incorrect address format:" << addr
               << " Please provide right address format, "
               << "e.g, 'socket://127.0.0.1:50051'. ";
  }
  std::string ip = ip_and_port[0];
  int port = stoi(ip_and_port[1]);
  // Initialize message queue for each connection
188
  num_sender_ = num_sender;
189
190
191
192
193
  for (int i = 0; i < num_sender_; ++i) {
    msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
  }
  // Initialize socket and socket-thread
  server_socket_ = new TCPSocket();
194
  server_socket_->SetTimeout(kTimeOut);  // seconds
195
  // Bind socket
196
  if (server_socket_->Bind(ip.c_str(), port) == false) {
197
    LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
198
199
  }
  // Listen
200
  if (server_socket_->Listen(kMaxConnection) == false) {
201
    LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
202
203
204
205
  }
  // Accept all sender sockets
  std::string accept_ip;
  int accept_port;
206
207
208
209
  for (int i = 0; i < num_sender_; ++i) {
    sockets_[i] = std::make_shared<TCPSocket>();
    if (server_socket_->Accept(sockets_[i].get(), &accept_ip, &accept_port) == false) {
      LOG(WARNING) << "Error on accept socket.";
210
211
      return false;
    }
212
    // create new thread for each socket
213
214
215
216
    threads_[i] = std::make_shared<std::thread>(
      RecvLoop,
      sockets_[i].get(),
      msg_queue_[i].get());
217
218
219
220
221
  }

  return true;
}

222
223
/*!
 * \brief Receive the next available message from any sender.
 *
 * Repeatedly sweeps over all message queues with a non-blocking Remove()
 * until one of them returns something other than QUEUE_EMPTY.
 *
 * \param msg output: the received message
 * \param send_id output: ID of the queue the last Remove() was issued on
 * \return the first status code that is not QUEUE_EMPTY
 */
STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
  while (true) {
    for (auto& entry : msg_queue_) {
      *send_id = entry.first;
      // Non-blocking remove: an empty queue just moves us on to the next one
      const STATUS status = entry.second->Remove(msg, false);
      if (status != QUEUE_EMPTY) {
        return status;
      }
    }
  }
}

/*!
 * \brief Receive a message from one specific sender.
 *
 * \param msg output: the received message
 * \param send_id ID of the sender whose message queue is read
 * \return status code returned by the queue's Remove()
 */
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
  return msg_queue_[send_id]->Remove(msg);
}

void SocketReceiver::Finalize() {
  // Send a signal to tell the message queue to finish its job
  for (auto& mq : msg_queue_) {
    // wait until queue is empty
    while (mq.second->Empty() == false) {
#ifdef _WIN32
        // just loop
#else   // !_WIN32
        usleep(1000);
#endif  // _WIN32
    }
    int ID = mq.first;
    mq.second->SignalFinished(ID);
  }
  // Block main thread until all socket-threads finish their jobs
  for (auto& thread : threads_) {
    thread.second->join();
  }
  // Clear all sockets
  for (auto& socket : sockets_) {
    socket.second->Close();
  }
}

/*!
 * \brief Thread body: read messages from the socket and add them to the queue.
 *
 * Every message is read as an int64_t size header followed by that many
 * payload bytes. The loop ends when the queue reports EmptyAndNoMoreAdd()
 * or when a header with size zero (the sender's end-signal) arrives.
 * A negative size or a failed allocation is reported through LOG(FATAL).
 *
 * \param socket connected socket to read from (must not be null)
 * \param queue message queue that receives the messages (must not be null)
 */
void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) {
  CHECK_NOTNULL(socket);
  CHECK_NOTNULL(queue);
  // Leave the loop thread once the main thread has finished its job
  while (!queue->EmptyAndNoMoreAdd()) {
    // Size header first; Receive() may deliver fewer bytes than requested
    int64_t data_size = 0;
    int64_t offset = 0;
    while (static_cast<size_t>(offset) < sizeof(int64_t)) {
      int64_t remaining = sizeof(int64_t) - offset;
      int64_t got = socket->Receive(
        reinterpret_cast<char*>(&data_size) + offset,
        remaining);
      CHECK_NE(got, -1);
      offset += got;
    }
    if (data_size == 0) {
      // This is an end-signal sent by client
      return;
    } else if (data_size < 0) {
      LOG(FATAL) << "Recv data error (data_size: " << data_size << ")";
    } else {
      char* payload = nullptr;
      try {
        payload = new char[data_size];
      } catch(const std::bad_alloc&) {
        LOG(FATAL) << "Cannot allocate enough memory for message, "
                   << "(message size: " << data_size << ")";
      }
      // Read the payload until data_size bytes have arrived
      for (offset = 0; offset < data_size;) {
        int64_t remaining = data_size - offset;
        int64_t got = socket->Receive(payload + offset, remaining);
        CHECK_NE(got, -1);
        offset += got;
      }
      // Pass the buffer to the queue together with its deleter
      Message msg;
      msg.data = payload;
      msg.size = data_size;
      msg.deallocator = DefaultMessageDeleter;
      queue->Add(msg);
    }
  }
}

}  // namespace network
}  // namespace dgl