/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator.h
 * \brief SocketCommunicator for DGL distributed training.
 */

#ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_

#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "../../runtime/semaphore_wrapper.h"

#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"
#include "common.h"

namespace dgl {
namespace network {

// Maximal try count (presumably for connection attempts): 1024
static constexpr int kMaxTryCount = 1024;
// 10 minutes (in seconds) for socket timeout
static constexpr int kTimeOut = 10 * 60;
// Maximal connection: 1024
static constexpr int kMaxConnection = 1024;

/*!
 * \brief Networking address
 *
 * Note: `port` has no default member initializer (kept so IPAddr stays an
 * aggregate). Value-initialize (`IPAddr addr{};`) if `port` may be read
 * before it is assigned.
 */
struct IPAddr {
  /*! \brief IP address, e.g. "127.0.0.1" (format presumed — TODO confirm) */
  std::string ip;
  /*! \brief port number */
  int port;
};

/*!
37
 * \brief SocketSender for DGL distributed training.
38
 *
39
 * SocketSender is the communicator implemented by tcp socket.
40
41
 */
class SocketSender : public Sender {
42
43
 public:
  /*!
44
45
   * \brief Sender constructor
   * \param queue_size size of message queue 
46
   * \param max_thread_count size of thread pool. 0 for no limit
47
   */
48
49
  SocketSender(int64_t queue_size, int max_thread_count)
    : Sender(queue_size, max_thread_count) {}
50
51
52
53

  /*!
   * \brief Add receiver's address and ID to the sender's namebook
   * \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
54
   * \param id receiver's ID
55
56
   *
   * AddReceiver() is not thread-safe and only one thread can invoke this API.
57
   */
58
  void AddReceiver(const char* addr, int recv_id);
59

60
  /*!
61
   * \brief Connect with all the Receivers
62
63
64
   * \return True for success and False for fail
   *
   * Connect() is not thread-safe and only one thread can invoke this API.
65
   */
66
  bool Connect();
67
68

  /*!
69
70
   * \brief Send data to specified Receiver. Actually pushing message to message queue.
   * \param msg data message
71
   * \param recv_id receiver's ID
72
73
74
75
76
77
78
79
   * \return Status code
   *
   * (1) The send is non-blocking. There is no guarantee that the message has been 
   *     physically sent out when the function returns.
   * (2) The communicator will assume the responsibility of the given message.
   * (3) The API is multi-thread safe.
   * (4) Messages sent to the same receiver are guaranteed to be received in the same order. 
   *     There is no guarantee for messages sent to different receivers.
80
   */
81
  STATUS Send(Message msg, int recv_id);
82
83

  /*!
84
85
86
   * \brief Finalize SocketSender
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
87
88
89
   */
  void Finalize();

90
  /*!
91
   * \brief Communicator type: 'socket'
92
   */
93
  inline std::string Type() const { return std::string("socket"); }
94

95
 private:
96
  /*!
97
98
   * \brief socket for each connection of receiver
   */ 
99
100
  std::vector<std::unordered_map<int /* receiver ID */,
    std::shared_ptr<TCPSocket>>> sockets_;
101

102
  /*!
103
   * \brief receivers' address
104
   */ 
105
  std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
106

107
  /*!
108
   * \brief message queue for each thread
109
   */ 
110
  std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
111
112

  /*!
113
   * \brief Independent thread
114
   */ 
115
  std::vector<std::shared_ptr<std::thread>> threads_;
116
117

  /*!
118
119
120
   * \brief Send-loop for each thread
   * \param sockets TCPSockets for current thread
   * \param queue message_queue for current thread
121
122
123
124
   * 
   * Note that, the SendLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
   */
125
126
127
128
  static void SendLoop(
    std::unordered_map<int /* Receiver (virtual) ID */,
      std::shared_ptr<TCPSocket>> sockets,
    std::shared_ptr<MessageQueue> queue);
129
130
131
};

/*!
132
 * \brief SocketReceiver for DGL distributed training.
133
 *
134
 * SocketReceiver is the communicator implemented by tcp socket.
135
136
137
138
 */
class SocketReceiver : public Receiver {
 public:
  /*!
139
140
   * \brief Receiver constructor
   * \param queue_size size of message queue.
141
   * \param max_thread_count size of thread pool. 0 for no limit
142
   */
143
144
  SocketReceiver(int64_t queue_size, int max_thread_count)
    : Receiver(queue_size, max_thread_count) {}
145
146

  /*!
147
148
149
150
151
152
   * \brief Wait for all the Senders to connect
   * \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
   * \param num_sender total number of Senders
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
153
   */
154
155
156
157
158
159
160
161
162
163
164
165
  bool Wait(const char* addr, int num_sender);

  /*!
   * \brief Recv data from Sender. Actually removing data from msg_queue.
   * \param msg pointer of data message
   * \param send_id which sender current msg comes from
   * \return Status code
   *
   * (1) The Recv() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The Recv() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
166
   */
167
168
169
170
171
172
173
174
175
176
177
178
179
180
  STATUS Recv(Message* msg, int* send_id);

  /*!
   * \brief Recv data from a specified Sender. Actually removing data from msg_queue.
   * \param msg pointer of data message
   * \param send_id sender's ID
   * \return Status code
   *
   * (1) The RecvFrom() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The RecvFrom() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
   */
  STATUS RecvFrom(Message* msg, int send_id);
181

182
  /*!
183
184
185
   * \brief Finalize SocketReceiver
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
186
   */
187
  void Finalize();
188
189

  /*!
190
   * \brief Communicator type: 'socket'
191
   */
192
  inline std::string Type() const { return std::string("socket"); }
193

194
 private:
195
196
197
198
199
  struct RecvContext {
    int64_t data_size = -1;
    int64_t received_bytes = 0;
    char *buffer = nullptr;
  };
200
201
202
203
204
205
  /*!
   * \brief number of sender
   */
  int num_sender_;

  /*!
206
   * \brief server socket for listening connections
207
   */ 
208
  TCPSocket* server_socket_;
209
210

  /*!
211
   * \brief socket for each client connections
212
   */ 
213
214
  std::vector<std::unordered_map<int /* Sender (virutal) ID */,
    std::shared_ptr<TCPSocket>>> sockets_;
215
216

  /*!
217
   * \brief Message queue for each socket connection
218
   */ 
219
220
  std::unordered_map<int /* Sender (virtual) ID */,
    std::shared_ptr<MessageQueue>> msg_queue_;
221
  std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
222

223
  /*!
224
   * \brief Independent thead
225
   */ 
226
  std::vector<std::shared_ptr<std::thread>> threads_;
227

228
  /*!
229
230
231
232
233
234
235
236
237
   * \brief queue_sem_ semphore to indicate number of messages in multiple
   * message queues to prevent busy wait of Recv
   */
  runtime::Semaphore queue_sem_;

  /*!
   * \brief Recv-loop for each thread
   * \param sockets client sockets of current thread
   * \param queue message queues of current thread
238
239
240
   *
   * Note that, the RecvLoop will finish its loop-job and exit thread
   * when the main thread invokes Signal() API on the message queue.
241
   */ 
242
243
244
245
246
247
  static void RecvLoop(
    std::unordered_map<int /* Sender (virtual) ID */,
      std::shared_ptr<TCPSocket>> sockets,
    std::unordered_map<int /* Sender (virtual) ID */,
      std::shared_ptr<MessageQueue>> queues,
    runtime::Semaphore *queue_sem);
248
249
250
251
252
};

}  // namespace network
}  // namespace dgl

#endif  // DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_