socket_communicator.h 8.07 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.h
 * \brief SocketCommunicator for DGL distributed training.
 */
6
7
#ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
8
9
10
11

#include <thread>
#include <vector>
#include <string>
12
#include <unordered_map>
13
#include <memory>
14

15
#include "../../runtime/semaphore_wrapper.h"
16
17
18
#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"
19
#include "common.h"
20
21
22
23

namespace dgl {
namespace network {

24
/*! \brief Socket timeout in seconds (10 minutes). */
static constexpr int kTimeOut = 10 * 60;
/*! \brief Maximal number of connections. */
static constexpr int kMaxConnection = 1024;
26
27

/*!
 * \brief Networking address: an IP string paired with a port number.
 */
struct IPAddr {
  /*! \brief IP address in text form */
  std::string ip;
  /*! \brief port number */
  int port;
};

/*!
36
 * \brief SocketSender for DGL distributed training.
37
 *
38
 * SocketSender is the communicator implemented by tcp socket.
39
40
 */
class SocketSender : public Sender {
41
42
 public:
  /*!
43
44
   * \brief Sender constructor
   * \param queue_size size of message queue 
45
   * \param max_thread_count size of thread pool. 0 for no limit
46
   */
47
48
  SocketSender(int64_t queue_size, int max_thread_count)
    : Sender(queue_size, max_thread_count) {}
49
50

  /*!
51
52
53
54
55
56
57
58
59
   * \brief Connect to a receiver.
   * 
   * When there are multiple receivers to be connected, application will call `ConnectReceiver`
   * for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
   * successfully established or some of them fail.
   * 
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * \param recv_id receiver's ID
   * \return True for success and False for fail
60
   *
61
   * The function is *not* thread-safe; only one thread can invoke this API.
62
   */
63
  bool ConnectReceiver(const std::string& addr, int recv_id) override;
64

65
  /*!
66
67
   * \brief Finalize the action to connect to receivers. Make sure that either
   *        all connections are successfully established or connection fails.
68
69
   * \return True for success and False for fail
   *
70
71
   * The function is *not* thread-safe; only one thread can invoke this API.
   */
72
  bool ConnectReceiverFinalize(const int max_try_times) override;
73
74
75
76
77
78
79
80
81
82

  /*!
   * \brief Send RPCMessage to specified Receiver.
   * \param msg data message 
   * \param recv_id receiver's ID
   */
  void Send(const rpc::RPCMessage& msg, int recv_id) override;

  /*!
   * \brief Finalize TPSender
83
   */
84
85
86
87
88
89
90
91
92
  void Finalize() override;

  /*!
   * \brief Communicator type: 'socket'
   */
  const std::string &NetType() const override {
    static const std::string net_type = "socket";
    return net_type;
  }
93
94

  /*!
95
96
   * \brief Send data to specified Receiver. Actually pushing message to message queue.
   * \param msg data message
97
   * \param recv_id receiver's ID
98
99
100
101
102
103
104
105
   * \return Status code
   *
   * (1) The send is non-blocking. There is no guarantee that the message has been 
   *     physically sent out when the function returns.
   * (2) The communicator will assume the responsibility of the given message.
   * (3) The API is multi-thread safe.
   * (4) Messages sent to the same receiver are guaranteed to be received in the same order. 
   *     There is no guarantee for messages sent to different receivers.
106
   */
107
  STATUS Send(Message msg, int recv_id) override;
108

109
 private:
110
  /*!
111
112
   * \brief socket for each connection of receiver
   */ 
113
114
  std::vector<std::unordered_map<int /* receiver ID */,
    std::shared_ptr<TCPSocket>>> sockets_;
115

116
  /*!
117
   * \brief receivers' address
118
   */ 
119
  std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
120

121
  /*!
122
   * \brief message queue for each thread
123
   */ 
124
  std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
125
126

  /*!
127
   * \brief Independent thread
128
   */ 
129
  std::vector<std::shared_ptr<std::thread>> threads_;
130
131

  /*!
132
133
134
   * \brief Send-loop for each thread
   * \param sockets TCPSockets for current thread
   * \param queue message_queue for current thread
135
136
137
138
   * 
   * Note that, the SendLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
   */
139
140
141
142
  static void SendLoop(
    std::unordered_map<int /* Receiver (virtual) ID */,
      std::shared_ptr<TCPSocket>> sockets,
    std::shared_ptr<MessageQueue> queue);
143
144
145
};

/*!
146
 * \brief SocketReceiver for DGL distributed training.
147
 *
148
 * SocketReceiver is the communicator implemented by tcp socket.
149
150
151
152
 */
class SocketReceiver : public Receiver {
 public:
  /*!
153
154
   * \brief Receiver constructor
   * \param queue_size size of message queue.
155
   * \param max_thread_count size of thread pool. 0 for no limit
156
   */
157
158
  SocketReceiver(int64_t queue_size, int max_thread_count)
    : Receiver(queue_size, max_thread_count) {}
159
160

  /*!
161
   * \brief Wait for all the Senders to connect
162
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
163
   * \param num_sender total number of Senders
164
   * \param blocking whether wait blockingly
165
166
167
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
168
   */
169
170
171
172
173
174
175
176
  bool Wait(const std::string &addr, int num_sender,
            bool blocking = true) override;

  /*!
   * \brief Recv RPCMessage from Sender. Actually removing data from queue.
   * \param msg pointer of RPCmessage
   */
  void Recv(rpc::RPCMessage* msg) override;
177
178
179
180
181
182
183
184
185
186
187

  /*!
   * \brief Recv data from Sender. Actually removing data from msg_queue.
   * \param msg pointer of data message
   * \param send_id which sender current msg comes from
   * \return Status code
   *
   * (1) The Recv() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The Recv() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
188
   */
189
  STATUS Recv(Message* msg, int* send_id) override;
190
191
192
193
194
195
196
197
198
199
200
201

  /*!
   * \brief Recv data from a specified Sender. Actually removing data from msg_queue.
   * \param msg pointer of data message
   * \param send_id sender's ID
   * \return Status code
   *
   * (1) The RecvFrom() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The RecvFrom() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
   */
202
  STATUS RecvFrom(Message* msg, int send_id) override;
203

204
  /*!
205
206
207
   * \brief Finalize SocketReceiver
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
208
   */
209
  void Finalize() override;
210
211

  /*!
212
   * \brief Communicator type: 'socket'
213
   */
214
215
216
217
  const std::string &NetType() const override {
    static const std::string net_type = "socket";
    return net_type;
  }
218

219
 private:
220
221
222
223
224
  struct RecvContext {
    int64_t data_size = -1;
    int64_t received_bytes = 0;
    char *buffer = nullptr;
  };
225
226
227
228
229
230
  /*!
   * \brief number of sender
   */
  int num_sender_;

  /*!
231
   * \brief server socket for listening connections
232
   */ 
233
  TCPSocket* server_socket_;
234
235

  /*!
236
   * \brief socket for each client connections
237
   */ 
238
239
  std::vector<std::unordered_map<int /* Sender (virutal) ID */,
    std::shared_ptr<TCPSocket>>> sockets_;
240
241

  /*!
242
   * \brief Message queue for each socket connection
243
   */ 
244
245
  std::unordered_map<int /* Sender (virtual) ID */,
    std::shared_ptr<MessageQueue>> msg_queue_;
246
  std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
247

248
  /*!
249
   * \brief Independent thead
250
   */ 
251
  std::vector<std::shared_ptr<std::thread>> threads_;
252

253
  /*!
254
255
256
257
258
259
260
261
262
   * \brief queue_sem_ semphore to indicate number of messages in multiple
   * message queues to prevent busy wait of Recv
   */
  runtime::Semaphore queue_sem_;

  /*!
   * \brief Recv-loop for each thread
   * \param sockets client sockets of current thread
   * \param queue message queues of current thread
263
264
265
   *
   * Note that, the RecvLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
266
   */ 
267
268
269
270
271
272
  static void RecvLoop(
    std::unordered_map<int /* Sender (virtual) ID */,
      std::shared_ptr<TCPSocket>> sockets,
    std::unordered_map<int /* Sender (virtual) ID */,
      std::shared_ptr<MessageQueue>> queues,
    runtime::Semaphore *queue_sem);
273
274
275
276
277
};

}  // namespace network
}  // namespace dgl

278
#endif  // DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_