/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.h
 * \brief SocketCommunicator for DGL distributed training.
 */
#ifndef DGL_GRAPH_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_GRAPH_NETWORK_SOCKET_COMMUNICATOR_H_

#include <cstdint>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"
#include "common.h"

namespace dgl {
namespace network {

// Tuning constants for the socket communicator. They are compile-time
// constants: a mutable `static int` in a header would give every translation
// unit its own private, writable copy.
static constexpr int kMaxTryCount = 1024;    // maximal number of tries (presumably connection attempts)
static constexpr int kTimeOut = 10;          // socket timeout, in minutes
static constexpr int kMaxConnection = 1024;  // maximal number of connections

/*!
 * \brief Networking address
 *
 * Plain aggregate holding the two components of a socket endpoint.
 */
struct IPAddr {
  /*! \brief IP part of the address, e.g. "127.0.0.1" */
  std::string ip;
  /*! \brief port number part of the address */
  int port;
};

/*!
35
 * \brief SocketSender for DGL distributed training.
36
 *
37
 * SocketSender is the communicator implemented by tcp socket.
38
39
 */
class SocketSender : public Sender {
40
41
 public:
  /*!
42
43
44
45
46
47
48
49
   * \brief Sender constructor
   * \param queue_size size of message queue 
   */
  explicit SocketSender(int64_t queue_size) : Sender(queue_size) {}

  /*!
   * \brief Add receiver's address and ID to the sender's namebook
   * \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
50
   * \param id receiver's ID
51
52
   *
   * AddReceiver() is not thread-safe and only one thread can invoke this API.
53
   */
54
  void AddReceiver(const char* addr, int recv_id);
55

56
  /*!
57
   * \brief Connect with all the Receivers
58
59
60
   * \return True for success and False for fail
   *
   * Connect() is not thread-safe and only one thread can invoke this API.
61
   */
62
  bool Connect();
63
64

  /*!
65
66
   * \brief Send data to specified Receiver. Actually pushing message to message queue.
   * \param msg data message
67
   * \param recv_id receiver's ID
68
69
70
71
72
73
74
75
   * \return Status code
   *
   * (1) The send is non-blocking. There is no guarantee that the message has been 
   *     physically sent out when the function returns.
   * (2) The communicator will assume the responsibility of the given message.
   * (3) The API is multi-thread safe.
   * (4) Messages sent to the same receiver are guaranteed to be received in the same order. 
   *     There is no guarantee for messages sent to different receivers.
76
   */
77
  STATUS Send(Message msg, int recv_id);
78
79

  /*!
80
81
82
   * \brief Finalize SocketSender
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
83
84
85
   */
  void Finalize();

86
  /*!
87
   * \brief Communicator type: 'socket'
88
   */
89
  inline std::string Type() const { return std::string("socket"); }
90

91
 private:
92
  /*!
93
94
95
   * \brief socket for each connection of receiver
   */ 
  std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>> sockets_;
96

97
  /*!
98
   * \brief receivers' address
99
   */ 
100
  std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
101

102
  /*!
103
   * \brief message queue for each socket connection
104
   */ 
105
  std::unordered_map<int /* receiver ID */, std::shared_ptr<MessageQueue>> msg_queue_;
106
107

  /*!
108
   * \brief Independent thread for each socket connection
109
   */ 
110
111
112
113
114
115
116
117
118
119
120
  std::unordered_map<int /* receiver ID */, std::shared_ptr<std::thread>> threads_;

  /*!
   * \brief Send-loop for each socket in per-thread
   * \param socket TCPSocket for current connection
   * \param queue message_queue for current connection
   * 
   * Note that, the SendLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
   */
  static void SendLoop(TCPSocket* socket, MessageQueue* queue);
121
122
123
};

/*!
124
 * \brief SocketReceiver for DGL distributed training.
125
 *
126
 * SocketReceiver is the communicator implemented by tcp socket.
127
128
129
130
 */
class SocketReceiver : public Receiver {
 public:
  /*!
131
132
   * \brief Receiver constructor
   * \param queue_size size of message queue.
133
   */
134
  explicit SocketReceiver(int64_t queue_size) : Receiver(queue_size) {}
135
136

  /*!
137
138
139
140
141
142
   * \brief Wait for all the Senders to connect
   * \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
   * \param num_sender total number of Senders
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
143
   */
144
145
146
147
148
149
150
151
152
153
154
155
  bool Wait(const char* addr, int num_sender);

  /*!
   * \brief Recv data from Sender. Actually removing data from msg_queue.
   * \param msg pointer of data message
   * \param send_id which sender current msg comes from
   * \return Status code
   *
   * (1) The Recv() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The Recv() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
156
   */
157
158
159
160
161
162
163
164
165
166
167
168
169
170
  STATUS Recv(Message* msg, int* send_id);

  /*!
   * \brief Recv data from a specified Sender. Actually removing data from msg_queue.
   * \param msg pointer of data message
   * \param send_id sender's ID
   * \return Status code
   *
   * (1) The RecvFrom() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The RecvFrom() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
   */
  STATUS RecvFrom(Message* msg, int send_id);
171

172
  /*!
173
174
175
   * \brief Finalize SocketReceiver
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
176
   */
177
  void Finalize();
178
179

  /*!
180
   * \brief Communicator type: 'socket'
181
   */
182
  inline std::string Type() const { return std::string("socket"); }
183

184
 private:
185
186
187
188
189
190
  /*!
   * \brief number of sender
   */
  int num_sender_;

  /*!
191
   * \brief server socket for listening connections
192
   */ 
193
  TCPSocket* server_socket_;
194
195

  /*!
196
   * \brief socket for each client connections
197
   */ 
198
  std::unordered_map<int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>> sockets_;
199
200

  /*!
201
   * \brief Message queue for each socket connection
202
   */ 
203
  std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>> msg_queue_;
204

205
  /*!
206
   * \brief Independent thead for each socket connection
207
   */ 
208
  std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<std::thread>> threads_;
209

210
  /*!
211
212
   * \brief Recv-loop for each socket in per-thread
   * \param socket client socket
213
   * \param queue message queue
214
215
216
   *
   * Note that, the RecvLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
217
   */ 
218
  static void RecvLoop(TCPSocket* socket, MessageQueue* queue);
219
220
221
222
223
224
};

}  // namespace network
}  // namespace dgl

#endif  // DGL_GRAPH_NETWORK_SOCKET_COMMUNICATOR_H_