socket_communicator.h 6.41 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.h
 * \brief SocketCommunicator for DGL distributed training.
 */
6
7
#ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
8
9
10
11

#include <thread>
#include <vector>
#include <string>
12
#include <unordered_map>
13
#include <memory>
14
15
16
17

#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"
18
#include "common.h"
19
20
21
22

namespace dgl {
namespace network {

23
24
25
/*!
 * \brief Upper bound on the number of tries.
 *
 * NOTE(review): presumably the retry limit when establishing a connection;
 * confirm against the implementation file.
 */
static constexpr int kMaxTryCount{1024};

/*! \brief Socket timeout: 10 minutes. */
static constexpr int kTimeOut{10};

/*! \brief Maximal number of connections: 1024. */
static constexpr int kMaxConnection{1024};
26
27

/*!
 * \brief Networking address
 *
 * Plain aggregate pairing an IP string with a port number.
 */
struct IPAddr {
  /*! \brief IP part of the address */
  std::string ip;
  /*! \brief port part of the address */
  int port;
};

/*!
36
 * \brief SocketSender for DGL distributed training.
37
 *
38
 * SocketSender is the communicator implemented by tcp socket.
39
40
 */
class SocketSender : public Sender {
41
42
 public:
  /*!
43
44
45
46
47
48
49
50
   * \brief Sender constructor
   * \param queue_size size of message queue 
   */
  explicit SocketSender(int64_t queue_size) : Sender(queue_size) {}

  /*!
   * \brief Add receiver's address and ID to the sender's namebook
   * \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
51
   * \param id receiver's ID
52
53
   *
   * AddReceiver() is not thread-safe and only one thread can invoke this API.
54
   */
55
  void AddReceiver(const char* addr, int recv_id);
56

57
  /*!
58
   * \brief Connect with all the Receivers
59
60
61
   * \return True for success and False for fail
   *
   * Connect() is not thread-safe and only one thread can invoke this API.
62
   */
63
  bool Connect();
64
65

  /*!
66
67
   * \brief Send data to specified Receiver. Actually pushing message to message queue.
   * \param msg data message
68
   * \param recv_id receiver's ID
69
70
71
72
73
74
75
76
   * \return Status code
   *
   * (1) The send is non-blocking. There is no guarantee that the message has been 
   *     physically sent out when the function returns.
   * (2) The communicator will assume the responsibility of the given message.
   * (3) The API is multi-thread safe.
   * (4) Messages sent to the same receiver are guaranteed to be received in the same order. 
   *     There is no guarantee for messages sent to different receivers.
77
   */
78
  STATUS Send(Message msg, int recv_id);
79
80

  /*!
81
82
83
   * \brief Finalize SocketSender
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
84
85
86
   */
  void Finalize();

87
  /*!
88
   * \brief Communicator type: 'socket'
89
   */
90
  inline std::string Type() const { return std::string("socket"); }
91

92
 private:
93
  /*!
94
95
96
   * \brief socket for each connection of receiver
   */ 
  std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>> sockets_;
97

98
  /*!
99
   * \brief receivers' address
100
   */ 
101
  std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
102

103
  /*!
104
   * \brief message queue for each socket connection
105
   */ 
106
  std::unordered_map<int /* receiver ID */, std::shared_ptr<MessageQueue>> msg_queue_;
107
108

  /*!
109
   * \brief Independent thread for each socket connection
110
   */ 
111
112
113
114
115
116
117
118
119
120
121
  std::unordered_map<int /* receiver ID */, std::shared_ptr<std::thread>> threads_;

  /*!
   * \brief Send-loop for each socket in per-thread
   * \param socket TCPSocket for current connection
   * \param queue message_queue for current connection
   * 
   * Note that, the SendLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
   */
  static void SendLoop(TCPSocket* socket, MessageQueue* queue);
122
123
124
};

/*!
125
 * \brief SocketReceiver for DGL distributed training.
126
 *
127
 * SocketReceiver is the communicator implemented by tcp socket.
128
129
130
131
 */
class SocketReceiver : public Receiver {
 public:
  /*!
132
133
   * \brief Receiver constructor
   * \param queue_size size of message queue.
134
   */
135
  explicit SocketReceiver(int64_t queue_size) : Receiver(queue_size) {}
136
137

  /*!
138
139
140
141
142
143
   * \brief Wait for all the Senders to connect
   * \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
   * \param num_sender total number of Senders
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
144
   */
145
146
147
148
149
150
151
152
153
154
155
156
  bool Wait(const char* addr, int num_sender);

  /*!
   * \brief Recv data from Sender. Actually removing data from msg_queue.
   * \param msg pointer of data message
   * \param send_id which sender current msg comes from
   * \return Status code
   *
   * (1) The Recv() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The Recv() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
157
   */
158
159
160
161
162
163
164
165
166
167
168
169
170
171
  STATUS Recv(Message* msg, int* send_id);

  /*!
   * \brief Recv data from a specified Sender. Actually removing data from msg_queue.
   * \param msg pointer of data message
   * \param send_id sender's ID
   * \return Status code
   *
   * (1) The RecvFrom() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The RecvFrom() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
   */
  STATUS RecvFrom(Message* msg, int send_id);
172

173
  /*!
174
175
176
   * \brief Finalize SocketReceiver
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
177
   */
178
  void Finalize();
179
180

  /*!
181
   * \brief Communicator type: 'socket'
182
   */
183
  inline std::string Type() const { return std::string("socket"); }
184

185
 private:
186
187
188
189
190
191
  /*!
   * \brief number of sender
   */
  int num_sender_;

  /*!
192
   * \brief server socket for listening connections
193
   */ 
194
  TCPSocket* server_socket_;
195
196

  /*!
197
   * \brief socket for each client connections
198
   */ 
199
  std::unordered_map<int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>> sockets_;
200
201

  /*!
202
   * \brief Message queue for each socket connection
203
   */ 
204
  std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>> msg_queue_;
205

206
  /*!
207
   * \brief Independent thead for each socket connection
208
   */ 
209
  std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<std::thread>> threads_;
210

211
  /*!
212
213
   * \brief Recv-loop for each socket in per-thread
   * \param socket client socket
214
   * \param queue message queue
215
216
217
   *
   * Note that, the RecvLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
218
   */ 
219
  static void RecvLoop(TCPSocket* socket, MessageQueue* queue);
220
221
222
223
224
};

}  // namespace network
}  // namespace dgl

225
#endif  // DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_