msg_queue.h 4.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
/*!
 *  Copyright (c) 2019 by Contributors
 * \file msg_queue.h
 * \brief Message queue for DGL distributed training.
 */
#ifndef DGL_GRAPH_NETWORK_MSG_QUEUE_H_
#define DGL_GRAPH_NETWORK_MSG_QUEUE_H_

#include <atomic>
#include <condition_variable>
#include <cstdint>  // for int64_t
#include <mutex>
#include <queue>
#include <set>
#include <string>
#include <utility>  // for pair

namespace dgl {
namespace network {

/*!
 * \brief Message Queue for DGL distributed training.
 *
 * MessageQueue is a circular queue that uses a ring buffer in a
 * producer/consumer model. It supports one or more producer
 * threads and one or more consumer threads. Producers invoke Add()
 * to push data elements into the queue, and consumers invoke
 * Remove() to pop data elements. Add() and Remove() use two condition
 * variables to synchronize producers and consumers. Each producer invokes
 * Signal(producer_id) to announce that it is about to finish, where
 * producer_id is an integer that uniquely identifies a producer thread. This
 * signaling mechanism prevents consumers from waiting after all producers
 * have finished their jobs.
 * 
 */
class MessageQueue {
 public:
  /*!
   * \brief MessageQueue constructor
   * \param queue_size size of the message queue in bytes
   * \param num_producers number of producers, 1 by default
   *
   * The constructor is explicit so that a plain integer is never
   * implicitly converted into a MessageQueue.
   */
  explicit MessageQueue(int64_t queue_size /* in bytes */,
                        int num_producers = 1);

  /*!
   * \brief MessageQueue destructor
   */
  ~MessageQueue();

  /*!
   * \brief Copying is disabled.
   *
   * The queue holds a raw buffer pointer together with a mutex and
   * condition variables. It was already implicitly non-copyable because
   * of the std::mutex member; deleting the operations makes that explicit.
   */
  MessageQueue(const MessageQueue&) = delete;
  MessageQueue& operator=(const MessageQueue&) = delete;

  /*!
   * \brief Add a message to the queue
   * \param src Pointer to the message data
   * \param size Size of the message in bytes
   * \param is_blocking If true, block until the message can be added;
   *        otherwise return immediately
   * \return bytes added to the queue
   *   > 0 : size of message
   *   = 0 : not enough space for this message (when is_blocking = false)
   *   - 1 : error
   */
  int64_t Add(const char* src, int64_t size, bool is_blocking = true);

  /*!
   * \brief Add a message to the queue
   * \param src The message data as a string
   * \param is_blocking If true, block until the message can be added;
   *        otherwise return immediately
   * \return bytes added to the queue
   *   > 0 : size of message
   *   = 0 : not enough space for this message (when is_blocking = false)
   *   - 1 : error
   */
  int64_t Add(const std::string& src, bool is_blocking = true);

  /*!
   * \brief Remove a message from the queue
   * \param dest Destination buffer that receives the message
   * \param max_size Maximal size of data that dest can hold
   * \param is_blocking If true, block until a message can be removed;
   *        otherwise return immediately
   * \return bytes removed from the queue
   *   > 0 : size of message
   *   = 0 : queue is empty
   *   - 1 : error
   */
  int64_t Remove(char *dest, int64_t max_size, bool is_blocking = true);

  /*!
   * \brief Remove a message from the queue
   * \param dest Destination string that receives the message
   * \param is_blocking If true, block until a message can be removed;
   *        otherwise return immediately
   * \return bytes removed from the queue
   *   > 0 : size of message
   *   = 0 : queue is empty
   *   - 1 : error
   */
  int64_t Remove(std::string *dest, bool is_blocking = true);

  /*!
   * \brief Signal that producer producer_id will no longer produce anything
   * \param producer_id An integer that uniquely identifies a producer thread
   */
  void Signal(int producer_id);

  /*!
   * \return true if the queue is empty and all num_producers have signaled.
   */
  bool EmptyAndNoMoreAdd() const;

 protected:
  /*!
   * \brief Location of one message inside queue_:
   *        first  = message start position in queue_,
   *        second = message length.
   */
  using MessagePosition = std::pair<int64_t /* message_start_position */,
                                    int64_t /* message_length */>;

  // NOTE: the in-class defaults below are defensive only; the constructor
  // (defined outside this header) is expected to set the real values.

  /*!
   * \brief Pointer to the queue buffer
   */
  char* queue_{nullptr};

  /*!
   * \brief Size of the queue in bytes
   */
  int64_t queue_size_{0};

  /*!
   * \brief Free size of the queue
   */
  int64_t free_size_{0};

  /*!
   * \brief Location in queue_ where the next element is written.
   * Note that no read pointer is needed: all messages are indexed
   * by message_positions_, and the first element in message_positions_
   * denotes where we should read.
   */
  int64_t write_pointer_{0};

  /*!
   * \brief Number of producers; used to check whether all producers
   *        will no longer produce anything
   */
  int num_producers_{0};

  /*!
   * \brief Positions of the messages currently in the queue
   */
  std::queue<MessagePosition> message_positions_;

  /*!
   * \brief Ids of the producers that have finished
   */
  std::set<int /* producer_id */> finished_producers_;

  /*!
   * \brief Condition variable for the "queue is not full" state.
   * NOTE(review): judging by the name, blocked producers wait on this one;
   * the previous comment said "consumer" - confirm against the implementation.
   */
  std::condition_variable cond_not_full_;

  /*!
   * \brief Condition variable for the "queue is not empty" state.
   * NOTE(review): judging by the name, blocked consumers wait on this one;
   * the previous comment said "producer" - confirm against the implementation.
   */
  std::condition_variable cond_not_empty_;

  /*!
   * \brief Flag used to signal that waiting should exit
   */
  std::atomic<bool> exit_flag_{false};

  /*!
   * \brief Protects all of the data and conditions above
   */
  mutable std::mutex mutex_;
};

}  // namespace network
}  // namespace dgl

#endif  // DGL_GRAPH_NETWORK_MSG_QUEUE_H_