/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.h
 * \brief SocketCommunicator for DGL distributed training.
 */
#ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_

#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "../../runtime/semaphore_wrapper.h"
#include "common.h"
#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"

namespace dgl {
namespace network {

/*! \brief Socket timeout in seconds (10 minutes). */
static constexpr int kTimeOut = 10 * 60;
/*! \brief Maximal number of connections. */
static constexpr int kMaxConnection = 1024;

/*!
 * \brief Networking address
 */
struct IPAddr {
  /*! \brief IP address, kept as text */
  std::string ip;
  /*!
   * \brief Port number.
   *
   * Defaults to 0 so that a default-constructed IPAddr never carries an
   * indeterminate port value.
   */
  int port = 0;
};

/*!
37
 * \brief SocketSender for DGL distributed training.
38
 *
39
 * SocketSender is the communicator implemented by tcp socket.
40
41
 */
class SocketSender : public Sender {
42
43
 public:
  /*!
44
   * \brief Sender constructor
45
   * \param queue_size size of message queue
46
   * \param max_thread_count size of thread pool. 0 for no limit
47
   */
48
  SocketSender(int64_t queue_size, int max_thread_count)
49
      : Sender(queue_size, max_thread_count) {}
50
51

  /*!
52
   * \brief Connect to a receiver.
53
54
55
56
57
58
   *
   * When there are multiple receivers to be connected, application will call
   * `ConnectReceiver` for each and then call `ConnectReceiverFinalize` to make
   * sure that either all the connections are successfully established or some
   * of them fail.
   *
59
60
61
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * \param recv_id receiver's ID
   * \return True for success and False for fail
62
   *
63
   * The function is *not* thread-safe; only one thread can invoke this API.
64
   */
65
  bool ConnectReceiver(const std::string& addr, int recv_id) override;
66

67
  /*!
68
69
   * \brief Finalize the action to connect to receivers. Make sure that either
   *        all connections are successfully established or connection fails.
70
71
   * \return True for success and False for fail
   *
72
73
   * The function is *not* thread-safe; only one thread can invoke this API.
   */
74
  bool ConnectReceiverFinalize(const int max_try_times) override;
75
76
77

  /*!
   * \brief Send RPCMessage to specified Receiver.
78
   * \param msg data message
79
80
81
82
83
84
   * \param recv_id receiver's ID
   */
  void Send(const rpc::RPCMessage& msg, int recv_id) override;

  /*!
   * \brief Finalize TPSender
85
   */
86
87
88
89
90
  void Finalize() override;

  /*!
   * \brief Communicator type: 'socket'
   */
91
  const std::string& NetType() const override {
92
93
94
    static const std::string net_type = "socket";
    return net_type;
  }
95
96

  /*!
97
98
99
100
101
   * \brief Send data to specified Receiver. Actually pushing message to message
   * queue.
   * \param msg data message.
   * \param recv_id receiver's ID.
   * \return Status code.
102
   *
103
104
105
106
107
108
   * (1) The send is non-blocking. There is no guarantee that the message has
   * been physically sent out when the function returns. (2) The communicator
   * will assume the responsibility of the given message. (3) The API is
   * multi-thread safe. (4) Messages sent to the same receiver are guaranteed to
   * be received in the same order. There is no guarantee for messages sent to
   * different receivers.
109
   */
110
  STATUS Send(Message msg, int recv_id) override;
111

112
 private:
113
  /*!
114
   * \brief socket for each connection of receiver
115
116
117
118
   */
  std::vector<
      std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>>>
      sockets_;
119

120
  /*!
121
   * \brief receivers' address
122
   */
123
  std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
124

125
  /*!
126
   * \brief message queue for each thread
127
   */
128
  std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
129
130

  /*!
131
   * \brief Independent thread
132
   */
133
  std::vector<std::shared_ptr<std::thread>> threads_;
134
135

  /*!
136
137
138
   * \brief Send-loop for each thread
   * \param sockets TCPSockets for current thread
   * \param queue message_queue for current thread
139
   *
140
141
142
   * Note that, the SendLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
   */
143
  static void SendLoop(
144
145
146
147
      std::unordered_map<
          int /* Receiver (virtual) ID */, std::shared_ptr<TCPSocket>>
          sockets,
      std::shared_ptr<MessageQueue> queue);
148
149
150
};

/*!
151
 * \brief SocketReceiver for DGL distributed training.
152
 *
153
 * SocketReceiver is the communicator implemented by tcp socket.
154
155
156
157
 */
class SocketReceiver : public Receiver {
 public:
  /*!
158
159
   * \brief Receiver constructor
   * \param queue_size size of message queue.
160
   * \param max_thread_count size of thread pool. 0 for no limit
161
   */
162
  SocketReceiver(int64_t queue_size, int max_thread_count)
163
      : Receiver(queue_size, max_thread_count) {}
164
165

  /*!
166
   * \brief Wait for all the Senders to connect
167
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
168
   * \param num_sender total number of Senders
169
   * \param blocking whether wait blockingly
170
171
172
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
173
   */
174
175
  bool Wait(
      const std::string& addr, int num_sender, bool blocking = true) override;
176
177
178
179

  /*!
   * \brief Recv RPCMessage from Sender. Actually removing data from queue.
   * \param msg pointer of RPCmessage
180
181
   * \param timeout The timeout value in milliseconds. If zero, wait
   * indefinitely.
182
   * \return RPCStatus: kRPCSuccess or kRPCTimeOut.
183
   */
184
  rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override;
185
186
187
188
189

  /*!
   * \brief Recv data from Sender. Actually removing data from msg_queue.
   * \param msg pointer of data message
   * \param send_id which sender current msg comes from
190
191
   * \param timeout The timeout value in milliseconds. If zero, wait
   * indefinitely.
192
193
   * \return Status code
   *
194
   * (1) The Recv() API is thread-safe.
195
196
   * (2) Memory allocated by communicator but will not own it after the function
   * returns.
197
   */
198
  STATUS Recv(Message* msg, int* send_id, int timeout = 0) override;
199
200

  /*!
201
202
203
   * \brief Recv data from a specified Sender. Actually removing data from
   * msg_queue.
   * \param msg pointer of data message.
204
   * \param send_id sender's ID
205
206
   * \param timeout The timeout value in milliseconds. If zero, wait
   * indefinitely.
207
208
   * \return Status code
   *
209
   * (1) The RecvFrom() API is thread-safe.
210
211
   * (2) Memory allocated by communicator but will not own it after the function
   * returns.
212
   */
213
  STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override;
214

215
  /*!
216
217
218
   * \brief Finalize SocketReceiver
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
219
   */
220
  void Finalize() override;
221
222

  /*!
223
   * \brief Communicator type: 'socket'
224
   */
225
  const std::string& NetType() const override {
226
227
228
    static const std::string net_type = "socket";
    return net_type;
  }
229

230
 private:
231
232
233
  struct RecvContext {
    int64_t data_size = -1;
    int64_t received_bytes = 0;
234
    char* buffer = nullptr;
235
  };
236
237
238
239
240
241
  /*!
   * \brief number of sender
   */
  int num_sender_;

  /*!
242
   * \brief server socket for listening connections
243
   */
244
  TCPSocket* server_socket_;
245
246

  /*!
247
   * \brief socket for each client connections
248
249
250
251
   */
  std::vector<std::unordered_map<
      int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>>>
      sockets_;
252
253

  /*!
254
   * \brief Message queue for each socket connection
255
256
257
258
   */
  std::unordered_map<
      int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
      msg_queue_;
259
  std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
260

261
  /*!
262
   * \brief Independent thead
263
   */
264
  std::vector<std::shared_ptr<std::thread>> threads_;
265

266
  /*!
267
268
269
270
271
272
273
274
275
   * \brief queue_sem_ semphore to indicate number of messages in multiple
   * message queues to prevent busy wait of Recv
   */
  runtime::Semaphore queue_sem_;

  /*!
   * \brief Recv-loop for each thread
   * \param sockets client sockets of current thread
   * \param queue message queues of current thread
276
277
278
   *
   * Note that, the RecvLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
279
   */
280
  static void RecvLoop(
281
282
283
284
285
286
287
      std::unordered_map<
          int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>
          sockets,
      std::unordered_map<
          int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
          queues,
      runtime::Semaphore* queue_sem);
288
289
290
291
292
};

}  // namespace network
}  // namespace dgl

#endif  // DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_