/*!
 *  Copyright (c) 2019 by Contributors
 * @file msg_queue.h
 * @brief Message queue for DGL distributed training.
 */

#ifndef DGL_RPC_NETWORK_MSG_QUEUE_H_
#define DGL_RPC_NETWORK_MSG_QUEUE_H_

#include <dgl/runtime/ndarray.h>

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <mutex>
#include <queue>
#include <set>
#include <string>
#include <utility>  // for pair

namespace dgl {
namespace network {

/*!
 * @brief Status type returned by MessageQueue::Add() and
 *        MessageQueue::Remove().
 */
using STATUS = int;

/*!
 * @brief Status codes of the message queue.
 *
 * These are preprocessor macros, so they are visible to code in any
 * namespace without qualification.
 */
#define ADD_SUCCESS 3400     /* message added successfully */
#define MSG_GT_SIZE 3401     /* message size exceeds the queue size */
#define MSG_LE_ZERO 3402     /* message size is not a positive number */
#define QUEUE_CLOSE 3403     /* cannot add: queue is closed */
#define QUEUE_FULL 3404      /* cannot add: queue is full */
#define REMOVE_SUCCESS 3405  /* message removed successfully */
#define QUEUE_EMPTY 3406     /* cannot remove: queue is empty */

/*!
 * @brief Message used by network communicator and message queue.
 *
 * A Message refers to its payload through a raw pointer and has no
 * destructor of its own, so it never frees `data` itself; `deallocator`
 * (when set) is the hook for releasing the payload. Copying a Message copies
 * the pointer, so copies share the same buffer.
 */
struct Message {
  /*!
   * @brief Construct an empty message (data == nullptr, size == 0).
   *
   * The in-class initializers below guarantee that a default-constructed
   * message never carries an indeterminate pointer or size.
   */
  Message() = default;

  /*!
   * @brief Construct a message that refers to an existing buffer.
   * @param data_ptr pointer to the message payload
   * @param data_size payload size in bytes
   */
  Message(char* data_ptr, int64_t data_size)
      : data(data_ptr), size(data_size) {}

  /*!
   * @brief message data (nullptr for an empty message)
   */
  char* data = nullptr;
  /*!
   * @brief message size in bytes
   */
  int64_t size = 0;
  /*!
   * @brief message receiver id (-1 when unset)
   */
  int receiver_id = -1;
  /*!
   * @brief user-defined deallocator, which can be nullptr
   */
  std::function<void(Message*)> deallocator = nullptr;
};

69
/*!
70
 * @brief Free memory buffer of message
71
 */
72
inline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; }
73
74

/*!
75
 * @brief Message Queue for network communication.
76
 *
77
78
79
80
81
82
83
84
85
 * MessageQueue is FIFO queue that adopts producer/consumer model for data
 * message. It supports one or more producer threads and one or more consumer
 * threads. Producers invokes Add() to push data message into the queue, and
 * consumers invokes Remove() to pop data message from queue. Add() and Remove()
 * use two condition variables to synchronize producer threads and consumer
 * threads. Each producer invokes SignalFinished(producer_id) to claim that it
 * is about to finish, where producer_id is an integer uniquely identify a
 * producer thread. This signaling mechanism prevents consumers from waiting
 * after all producers have finished their jobs.
86
 *
87
 * MessageQueue is thread-safe.
88
 *
89
90
91
92
 */
class MessageQueue {
 public:
  /*!
93
94
95
   * @brief MessageQueue constructor
   * @param queue_size size (bytes) of message queue
   * @param num_producers number of producers, use 1 by default
96
   */
97
98
  explicit MessageQueue(
      int64_t queue_size /* in bytes */, int num_producers = 1);
99
100

  /*!
101
   * @brief MessageQueue deconstructor
102
   */
103
  ~MessageQueue() {}
104
105

  /*!
106
107
108
109
   * @brief Add message to the queue
   * @param msg data message
   * @param is_blocking Blocking if cannot add, else return
   * @return Status code
110
   */
111
  STATUS Add(Message msg, bool is_blocking = true);
112
113

  /*!
114
115
116
117
   * @brief Remove message from the queue
   * @param msg pointer of data msg
   * @param is_blocking Blocking if cannot remove, else return
   * @return Status code
118
   */
119
  STATUS Remove(Message* msg, bool is_blocking = true);
120
121

  /*!
122
123
   * @brief Signal that producer producer_id will no longer produce anything
   * @param producer_id An integer uniquely to identify a producer thread
124
   */
125
  void SignalFinished(int producer_id);
126
127

  /*!
128
   * @return true if queue is empty.
129
   */
130
  bool Empty() const;
131
132

  /*!
133
   * @return true if queue is empty and all num_producers have signaled.
134
135
136
137
   */
  bool EmptyAndNoMoreAdd() const;

 protected:
138
  /*!
139
   * @brief message queue
140
   */
141
  std::queue<Message> queue_;
142

143
  /*!
144
   * @brief Size of the queue in bytes
145
146
147
   */
  int64_t queue_size_;

148
  /*!
149
   * @brief Free size of the queue
150
151
152
   */
  int64_t free_size_;

153
  /*!
154
   * @brief Used to check all producers will no longer produce anything
155
   */
Da Zheng's avatar
Da Zheng committed
156
  size_t num_producers_;
157

158
  /*!
159
   * @brief Store finished producer id
160
161
162
   */
  std::set<int /* producer_id */> finished_producers_;

163
  /*!
164
   * @brief Condition when consumer should wait
165
166
167
   */
  std::condition_variable cond_not_full_;

168
  /*!
169
   * @brief Condition when producer should wait
170
171
172
   */
  std::condition_variable cond_not_empty_;

173
  /*!
174
   * @brief Signal for exit wait
175
176
177
   */
  std::atomic<bool> exit_flag_{false};

178
  /*!
179
   * @brief Protect all above data and conditions
180
181
182
183
184
185
186
   */
  mutable std::mutex mutex_;
};

}  // namespace network
}  // namespace dgl

#endif  // DGL_RPC_NETWORK_MSG_QUEUE_H_