socket_communicator.h 7.91 KB
Newer Older
1
/**
 *  Copyright (c) 2019 by Contributors
 * @file socket_communicator.h
 * @brief SocketCommunicator (SocketSender / SocketReceiver) for DGL distributed
 *        training.
 */
6
7
#ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
8

9
#include <memory>
10
#include <string>
11
#include <thread>
12
#include <unordered_map>
13
#include <vector>
14

15
#include "../../runtime/semaphore_wrapper.h"
16
#include "../rpc_msg.h"
17
#include "common.h"
18
19
20
21
22
23
24
#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"

namespace dgl {
namespace network {

/** @brief Socket timeout in seconds (10 minutes). */
static constexpr int kTimeOut = 10 * 60;

/** @brief Maximal number of connections (1024). */
static constexpr int kMaxConnection = 1024;

/**
 * @brief Networking address: an IP string plus a port number.
 */
struct IPAddr {
  /** @brief IP address, e.g. "127.0.0.1". */
  std::string ip;
  /**
   * @brief Port number. Defaults to 0 so that a default-constructed address
   *        never carries an indeterminate port value.
   */
  int port = 0;
};

37
/**
38
 * @brief SocketSender for DGL distributed training.
39
 *
40
 * SocketSender is the communicator implemented by tcp socket.
41
42
 */
class SocketSender : public Sender {
43
 public:
44
  /**
45
46
47
   * @brief Sender constructor
   * @param queue_size size of message queue
   * @param max_thread_count size of thread pool. 0 for no limit
48
   */
49
  SocketSender(int64_t queue_size, int max_thread_count)
50
      : Sender(queue_size, max_thread_count) {}
51

52
  /**
53
   * @brief Connect to a receiver.
54
55
56
57
58
59
   *
   * When there are multiple receivers to be connected, application will call
   * `ConnectReceiver` for each and then call `ConnectReceiverFinalize` to make
   * sure that either all the connections are successfully established or some
   * of them fail.
   *
60
61
62
   * @param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * @param recv_id receiver's ID
   * @return True for success and False for fail
63
   *
64
   * The function is *not* thread-safe; only one thread can invoke this API.
65
   */
66
  bool ConnectReceiver(const std::string& addr, int recv_id);
67

68
  /**
69
   * @brief Finalize the action to connect to receivers. Make sure that either
70
   *        all connections are successfully established or connection fails.
71
   * @return True for success and False for fail
72
   *
73
74
   * The function is *not* thread-safe; only one thread can invoke this API.
   */
75
  bool ConnectReceiverFinalize(const int max_try_times);
76

77
  /**
78
79
80
   * @brief Send RPCMessage to specified Receiver.
   * @param msg data message
   * @param recv_id receiver's ID
81
   */
82
  void Send(const rpc::RPCMessage& msg, int recv_id);
83

84
  /**
85
   * @brief Finalize TPSender
86
   */
87
  void Finalize();
88

89
  /**
90
   * @brief Send data to specified Receiver. Actually pushing message to message
91
   * queue.
92
93
94
   * @param msg data message.
   * @param recv_id receiver's ID.
   * @return Status code.
95
   *
96
97
98
99
100
101
   * (1) The send is non-blocking. There is no guarantee that the message has
   * been physically sent out when the function returns. (2) The communicator
   * will assume the responsibility of the given message. (3) The API is
   * multi-thread safe. (4) Messages sent to the same receiver are guaranteed to
   * be received in the same order. There is no guarantee for messages sent to
   * different receivers.
102
   */
103
  STATUS Send(Message msg, int recv_id) override;
104

105
 private:
106
  /**
107
   * @brief socket for each connection of receiver
108
109
110
111
   */
  std::vector<
      std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>>>
      sockets_;
112

113
  /**
114
   * @brief receivers' address
115
   */
116
  std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
117

118
  /**
119
   * @brief message queue for each thread
120
   */
121
  std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
122

123
  /**
124
   * @brief Independent thread
125
   */
126
  std::vector<std::shared_ptr<std::thread>> threads_;
127

128
  /**
129
130
131
   * @brief Send-loop for each thread
   * @param sockets TCPSockets for current thread
   * @param queue message_queue for current thread
132
   *
133
134
135
   * Note that, the SendLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
   */
136
  static void SendLoop(
137
138
139
140
      std::unordered_map<
          int /* Receiver (virtual) ID */, std::shared_ptr<TCPSocket>>
          sockets,
      std::shared_ptr<MessageQueue> queue);
141
142
};

143
/**
144
 * @brief SocketReceiver for DGL distributed training.
145
 *
146
 * SocketReceiver is the communicator implemented by tcp socket.
147
148
149
 */
class SocketReceiver : public Receiver {
 public:
150
  /**
151
152
153
   * @brief Receiver constructor
   * @param queue_size size of message queue.
   * @param max_thread_count size of thread pool. 0 for no limit
154
   */
155
  SocketReceiver(int64_t queue_size, int max_thread_count)
156
      : Receiver(queue_size, max_thread_count) {}
157

158
  /**
159
160
161
162
163
   * @brief Wait for all the Senders to connect
   * @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
   * @param num_sender total number of Senders
   * @param blocking whether wait blockingly
   * @return True for success and False for fail
164
165
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
166
   */
167
  bool Wait(const std::string& addr, int num_sender, bool blocking = true);
168

169
  /**
170
171
172
   * @brief Recv RPCMessage from Sender. Actually removing data from queue.
   * @param msg pointer of RPCmessage
   * @param timeout The timeout value in milliseconds. If zero, wait
173
   * indefinitely.
174
   * @return RPCStatus: kRPCSuccess or kRPCTimeOut.
175
   */
176
  rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout);
177

178
  /**
179
180
181
182
   * @brief Recv data from Sender. Actually removing data from msg_queue.
   * @param msg pointer of data message
   * @param send_id which sender current msg comes from
   * @param timeout The timeout value in milliseconds. If zero, wait
183
   * indefinitely.
184
   * @return Status code
185
   *
186
   * (1) The Recv() API is thread-safe.
187
188
   * (2) Memory allocated by communicator but will not own it after the function
   * returns.
189
   */
190
  STATUS Recv(Message* msg, int* send_id, int timeout = 0) override;
191

192
  /**
193
   * @brief Recv data from a specified Sender. Actually removing data from
194
   * msg_queue.
195
196
197
   * @param msg pointer of data message.
   * @param send_id sender's ID
   * @param timeout The timeout value in milliseconds. If zero, wait
198
   * indefinitely.
199
   * @return Status code
200
   *
201
   * (1) The RecvFrom() API is thread-safe.
202
203
   * (2) Memory allocated by communicator but will not own it after the function
   * returns.
204
   */
205
  STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override;
206

207
  /**
208
   * @brief Finalize SocketReceiver
209
210
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
211
   */
212
  void Finalize();
213

214
 private:
215
216
217
  struct RecvContext {
    int64_t data_size = -1;
    int64_t received_bytes = 0;
218
    char* buffer = nullptr;
219
  };
220
  /**
221
   * @brief number of sender
222
223
224
   */
  int num_sender_;

225
  /**
226
   * @brief server socket for listening connections
227
   */
228
  TCPSocket* server_socket_;
229

230
  /**
231
   * @brief socket for each client connections
232
233
234
235
   */
  std::vector<std::unordered_map<
      int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>>>
      sockets_;
236

237
  /**
238
   * @brief Message queue for each socket connection
239
240
241
242
   */
  std::unordered_map<
      int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
      msg_queue_;
243
  std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
244

245
  /**
246
   * @brief Independent thead
247
   */
248
  std::vector<std::shared_ptr<std::thread>> threads_;
249

250
  /**
251
   * @brief queue_sem_ semphore to indicate number of messages in multiple
252
253
254
255
   * message queues to prevent busy wait of Recv
   */
  runtime::Semaphore queue_sem_;

256
  /**
257
258
259
   * @brief Recv-loop for each thread
   * @param sockets client sockets of current thread
   * @param queue message queues of current thread
260
261
262
   *
   * Note that, the RecvLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
263
   */
264
  static void RecvLoop(
265
266
267
268
269
270
271
      std::unordered_map<
          int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>
          sockets,
      std::unordered_map<
          int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
          queues,
      runtime::Semaphore* queue_sem);
272
273
274
275
276
};

}  // namespace network
}  // namespace dgl

277
#endif  // DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_