"vscode:/vscode.git/clone" did not exist on "3386f466ac5400b91b2b39f62634774af8810850"
socket_pool.cc 2.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/*!
 *  Copyright (c) 2021 by Contributors
 * \file socket_pool.cc
 * \brief Socket pool of nonblocking sockets for DGL distributed training.
 */
#include "socket_pool.h"

#include <cerrno>
#include <cstring>

#include <dmlc/logging.h>

#include "tcp_socket.h"

#ifdef USE_EPOLL
#include <sys/epoll.h>
#include <unistd.h>
#endif

namespace dgl {
namespace network {

/*!
 * \brief Construct an empty pool.
 *
 * With USE_EPOLL, this creates the epoll instance that AddSocket() registers
 * descriptors with and Wait() blocks on; failure to create it is fatal.
 * Without USE_EPOLL, nothing is set up here.
 */
SocketPool::SocketPool() {
#ifdef USE_EPOLL
  epfd_ = epoll_create1(0);
  const bool created = epfd_ >= 0;
  if (!created) {
    LOG(FATAL) << "SocketPool cannot create epfd";
  }
#endif
}

/*!
 * \brief Register a socket with the pool.
 *
 * The pool keeps a shared reference to \p socket, keyed by its file
 * descriptor, and remembers \p socket_id so GetActiveSocket() can report it.
 *
 * \param socket TCP socket to register.
 * \param socket_id caller-chosen id reported back by GetActiveSocket().
 * \param events READ, WRITE, or READ + WRITE; selects EPOLLIN / EPOLLOUT.
 *        Any other value is fatal. Only consulted when built with USE_EPOLL.
 *
 * With USE_EPOLL the socket is also switched to non-blocking mode. Without
 * USE_EPOLL only a single socket may be registered.
 */
void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
  int events) {
  int fd = socket->Socket();
  tcp_sockets_[fd] = socket;
  socket_ids_[fd] = socket_id;

#ifdef USE_EPOLL
  // Value-initialize so neither the event mask nor the unused bytes of the
  // data union are left indeterminate.
  epoll_event e{};
  e.data.fd = fd;
  if (events == READ) {
    e.events = EPOLLIN;
  } else if (events == WRITE) {
    e.events = EPOLLOUT;
  } else if (events == READ + WRITE) {
    e.events = EPOLLIN | EPOLLOUT;
  } else {
    // Without this branch an unrecognized mask would be registered with an
    // arbitrary event set.
    LOG(FATAL) << "SocketPool unsupported events mask: " << events;
  }
  if (epoll_ctl(epfd_, EPOLL_CTL_ADD, fd, &e) < 0) {
    LOG(FATAL) << "SocketPool cannot add socket: " << std::strerror(errno);
  }
  socket->SetNonBlocking(true);
#else
  // The non-epoll Wait() can only ever report a single socket.
  if (tcp_sockets_.size() > 1) {
    LOG(FATAL) << "SocketPool supports only one socket if not use epoll. "
      "Please turn on USE_EPOLL on building";
  }
#endif
}

/*!
 * \brief Unregister a socket from the pool.
 *
 * Descriptors of this socket that are already queued in pending_fds_ are not
 * purged here; GetActiveSocket() skips them because they are no longer in
 * socket_ids_.
 *
 * \param socket socket previously passed to AddSocket().
 * \return number of sockets still registered after the removal.
 */
size_t SocketPool::RemoveSocket(std::shared_ptr<TCPSocket> socket) {
  const int removed_fd = socket->Socket();
#ifdef USE_EPOLL
  // The result of the deregistration is not checked.
  epoll_ctl(epfd_, EPOLL_CTL_DEL, removed_fd, nullptr);
#endif
  tcp_sockets_.erase(removed_fd);
  socket_ids_.erase(removed_fd);
  const size_t remaining = socket_ids_.size();
  return remaining;
}

/*!
 * \brief Destroy the pool.
 *
 * With USE_EPOLL, every still-registered descriptor is removed from the epoll
 * instance and the epoll descriptor itself is closed. The original code never
 * closed it, leaking one fd per pool. The TCP sockets are not closed here;
 * the pool only drops its shared references to them.
 *
 * NOTE(review): socket_pool.h is not visible here. If SocketPool is
 * copyable, two copies would now both close the same epfd_; consider
 * deleting the copy operations in the header.
 */
SocketPool::~SocketPool() {
#ifdef USE_EPOLL
  for (const auto& id : socket_ids_) {
    epoll_ctl(epfd_, EPOLL_CTL_DEL, id.first, nullptr);
  }
  if (epfd_ >= 0) {
    close(epfd_);
  }
#endif
}

/*!
 * \brief Block until a registered socket is reported ready and return it.
 *
 * \param socket_id output; set to the id passed to AddSocket() for the
 *        returned socket. Left untouched when nullptr is returned.
 * \return the ready socket, or nullptr if the pool has no sockets at all.
 *
 * Descriptors queued by Wait() whose socket has since been removed are
 * discarded and waiting continues.
 */
std::shared_ptr<TCPSocket> SocketPool::GetActiveSocket(int* socket_id) {
  if (socket_ids_.empty()) {
    return nullptr;
  }

  for (;;) {
    while (pending_fds_.empty()) {
      Wait();
    }
    const int fd = pending_fds_.front();
    pending_fds_.pop();

    // Single lookup: skip sockets removed after their fd was queued.
    auto id_it = socket_ids_.find(fd);
    if (id_it == socket_ids_.end()) {
      continue;
    }
    *socket_id = id_it->second;
    return tcp_sockets_[fd];
  }
}

void SocketPool::Wait() {
#ifdef USE_EPOLL
  static const int MAX_EVENTS = 10;
  epoll_event events[MAX_EVENTS];
  int nfd = epoll_wait(epfd_, events, MAX_EVENTS, -1 /*Timeout*/);
  for (int i = 0; i < nfd; ++i) {
    pending_fds_.push(events[i].data.fd);
  }
#else
  pending_fds_.push(tcp_sockets_.begin()->second->Socket());
#endif
}

}  // namespace network
}  // namespace dgl