/**
 *  Copyright (c) 2021 by Contributors
 * @file socket_pool.h
 * @brief Socket pool of nonblocking sockets for DGL distributed training.
 */
#ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_
#define DGL_RPC_NETWORK_SOCKET_POOL_H_

#include <cstddef>
#include <memory>
#include <queue>
#include <unordered_map>

namespace dgl {
namespace network {

class TCPSocket;

/**
 * @brief SocketPool maintains a group of nonblocking sockets and can provide
 * the sockets that are currently active.
 *
 * Currently SocketPool is based on epoll, a scalable I/O event notification
 * mechanism in the Linux operating system.
 *
 * NOTE(review): the pool holds a raw epoll descriptor (epfd_) but declares no
 * copy/move operations. If the destructor closes epfd_, copying a SocketPool
 * would close the same descriptor twice. Confirm against the implementation
 * file and consider deleting the copy operations.
 */
class SocketPool {
 public:
  /**
   * @brief Event flag: socket mode read/receive.
   */
  static const int READ = 1;

  /**
   * @brief Event flag: socket mode write/send.
   */
  static const int WRITE = 2;

  /**
   * @brief SocketPool constructor.
   */
  SocketPool();

  /**
   * @brief Add a socket to SocketPool.
   * @param socket tcp socket to add
   * @param socket_id receiver/sender id of the socket
   * @param events READ, WRITE, or READ + WRITE. The flags occupy distinct
   *        bits (1 and 2), so READ + WRITE equals READ | WRITE.
   */
  void AddSocket(
      std::shared_ptr<TCPSocket> socket, int socket_id, int events = READ);

  /**
   * @brief Remove a socket from SocketPool.
   * @param socket tcp socket to remove
   * @return number of remaining sockets in the pool
   */
  size_t RemoveSocket(std::shared_ptr<TCPSocket> socket);

  /**
   * @brief SocketPool destructor.
   */
  ~SocketPool();

  /**
   * @brief Get the current active socket. This is a blocking method.
   * @param socket_id output parameter that receives the socket_id of the
   *        active socket
   * @return active TCPSocket
   */
  std::shared_ptr<TCPSocket> GetActiveSocket(int* socket_id);

 private:
  /**
   * @brief Wait for event notification.
   *
   * Presumably refills pending_fds_ with descriptors that have events;
   * TODO confirm in the implementation file.
   */
  void Wait();

  /**
   * @brief Map from fd to TCPSocket.
   */
  std::unordered_map<int, std::shared_ptr<TCPSocket>> tcp_sockets_;

  /**
   * @brief Map from fd to socket_id.
   */
  std::unordered_map<int, int> socket_ids_;

  /**
   * @brief fd of the epoll instance.
   *
   * Defaults to -1 (not a valid descriptor) so the member never holds an
   * indeterminate value. The constructor is expected to assign the real
   * descriptor.
   */
  int epfd_ = -1;

  /**
   * @brief Queue of fds that currently have pending events.
   */
  std::queue<int> pending_fds_;
};

}  // namespace network
}  // namespace dgl

#endif  // DGL_RPC_NETWORK_SOCKET_POOL_H_