/*!
 *  Copyright (c) 2021 by Contributors
 * @file socket_pool.h
 * @brief Socket pool of nonblocking sockets for DGL distributed training.
 */
#ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_
#define DGL_RPC_NETWORK_SOCKET_POOL_H_

#include <cstddef>
#include <memory>
#include <queue>
#include <unordered_map>

namespace dgl {
namespace network {

class TCPSocket;

/*!
 * @brief SocketPool maintains a group of nonblocking sockets, and can provide
 * active sockets.
 * Currently SocketPool is based on epoll, a scalable I/O event notification
 * mechanism in Linux operating system.
 *
 * NOTE(review): this class holds a raw descriptor (epfd_) and declares a
 * destructor, but its copy constructor/assignment are still implicitly
 * generated. If ~SocketPool (defined out of line) closes epfd_ -- confirm --
 * copying a SocketPool would release the same descriptor twice. Do not copy
 * instances.
 */
class SocketPool {
 public:
  /*!
   * @brief socket mode read/receive
   */
  static const int READ = 1;
  /*!
   * @brief socket mode write/send
   */
  static const int WRITE = 2;
  /*!
   * @brief SocketPool constructor
   */
  SocketPool();

  /*!
   * @brief Add a socket to SocketPool
   * @param socket tcp socket to add
   * @param socket_id receiver/sender id of the socket
   * @param events READ, WRITE or READ + WRITE (READ and WRITE are distinct
   *        bits, so READ | WRITE yields the same value)
   */
  void AddSocket(
      std::shared_ptr<TCPSocket> socket, int socket_id, int events = READ);

  /*!
   * @brief Remove socket from SocketPool
   * @param socket tcp socket to remove
   * @return number of remaining sockets in the pool
   */
  std::size_t RemoveSocket(std::shared_ptr<TCPSocket> socket);

  /*!
   * @brief SocketPool destructor
   */
  ~SocketPool();

  /*!
   * @brief Get current active socket. This is a blocking method
   * @param socket_id output parameter of the socket_id of active socket
   * @return active TCPSocket
   */
  std::shared_ptr<TCPSocket> GetActiveSocket(int* socket_id);

 private:
  /*!
   * @brief Wait for event notification
   */
  void Wait();

  /*!
   * @brief map from fd to TCPSocket
   */
  std::unordered_map<int, std::shared_ptr<TCPSocket>> tcp_sockets_;

  /*!
   * @brief map from fd to socket_id
   */
  std::unordered_map<int, int> socket_ids_;

  /*!
   * @brief fd for epoll base
   *
   * Defaults to -1 (no descriptor) so the member is never read uninitialized;
   * presumably replaced with a real epoll fd by the constructor.
   */
  int epfd_ = -1;

  /*!
   * @brief queue for current active fds
   */
  std::queue<int> pending_fds_;
};

}  // namespace network
}  // namespace dgl

#endif  // DGL_RPC_NETWORK_SOCKET_POOL_H_