msg_queue.cc 5.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
/*!
 *  Copyright (c) 2019 by Contributors
 * \file msg_queue.cc
 * \brief Message queue for DGL distributed training.
 */
#include <dmlc/logging.h>
#include <cstring>

#include "msg_queue.h"

namespace dgl {
namespace network {

using std::string;

/*!
 * \brief Allocate a zero-filled byte buffer and reset the queue bookkeeping.
 * \param queue_size capacity of the buffer in bytes; must be positive
 * \param num_producers number of distinct producer ids that must be passed
 *        to Signal() before the queue is treated as finished
 */
MessageQueue::MessageQueue(int64_t queue_size, int num_producers) {
  CHECK_LT(0, queue_size);

  // Bookkeeping: the whole buffer is free and writing starts at offset 0.
  queue_size_ = queue_size;
  free_size_ = queue_size;
  write_pointer_ = 0;
  num_producers_ = num_producers;

  // Allocation failure is reported through the fatal log channel.
  try {
    queue_ = new char[queue_size];
  } catch(const std::bad_alloc&) {
    LOG(FATAL) << "Not enough memory for message queue.";
  }
  std::memset(queue_, 0, queue_size);
}

/*!
 * \brief Release the byte buffer while holding the queue mutex.
 */
MessageQueue::~MessageQueue() {
  std::lock_guard<std::mutex> guard(mutex_);
  // delete[] on a null pointer is a no-op, so no explicit check is needed.
  delete[] queue_;
  queue_ = nullptr;
}

/*!
 * \brief Copy one message into the ring buffer.
 * \param src pointer to the message bytes
 * \param size number of bytes to copy; must be in (0, queue_size_]
 * \param is_blocking if true, wait until enough space is free;
 *        if false, drop the message when it does not fit right now
 * \return size on success, 0 if the message was dropped because the queue
 *         is full (non-blocking mode), -1 on invalid size or when all
 *         producers have already finished
 */
int64_t MessageQueue::Add(const char* src, int64_t size, bool is_blocking) {
  // A message that can never fit is rejected up front.
  if (size > queue_size_) {
    LOG(ERROR) << "Message is larger than the queue.";
    return -1;
  }
  if (size <= 0) {
    LOG(ERROR) << "Message size (" << size << ") is negative or zero.";
    return -1;
  }

  std::unique_lock<std::mutex> guard(mutex_);
  if (finished_producers_.size() >= num_producers_) {
    LOG(ERROR) << "Can't add to buffer when flag_no_more is set.";
    return -1;
  }
  if (!is_blocking && size > free_size_) {
    LOG(WARNING) << "Queue is full and message lost.";
    return 0;
  }
  // Sleep until a consumer has released enough space for this message.
  while (size > free_size_) {
    cond_not_full_.wait(guard);
  }

  // Record where the message starts and how long it is.
  const auto start = write_pointer_;
  message_positions_.push(std::make_pair(start, size));
  free_size_ -= size;

  // Bytes available between the write position and the end of the buffer.
  const int64_t tail_room = queue_size_ - start;
  if (size <= tail_room) {
    // The message fits contiguously; wrap the write position to the head
    // when it lands exactly on the end of the buffer.
    memcpy(queue_ + start, src, size);
    write_pointer_ = (size == tail_room) ? 0 : start + size;
  } else {
    // Fill the tail first, then continue from the head of the buffer.
    const int64_t wrapped = size - tail_room;
    memcpy(queue_ + start, src, tail_room);
    memcpy(queue_, src + tail_room, wrapped);
    write_pointer_ = wrapped;
  }

  // Wake one consumer waiting for data.
  cond_not_empty_.notify_one();

  return size;
}

/*!
 * \brief Convenience overload: enqueue the bytes held by a string.
 * \param src message payload; an empty string is rejected by the
 *        pointer/size overload because its size is zero
 * \param is_blocking forwarded unchanged to the pointer/size overload
 * \return whatever the pointer/size overload returns
 */
int64_t MessageQueue::Add(const string &src, bool is_blocking) {
  const int64_t num_bytes = src.size();
  return Add(src.data(), num_bytes, is_blocking);
}

int64_t MessageQueue::Remove(char *dest, int64_t max_size, bool is_blocking) {
  int64_t retval;

  std::unique_lock<std::mutex> lock(mutex_);
  if (message_positions_.empty()) {
    if (!is_blocking) {
      return 0;
    }
    if (finished_producers_.size() >= num_producers_) {
      return 0;
    }
  }

  cond_not_empty_.wait(lock, [this] {
    return !message_positions_.empty() || exit_flag_.load();
  });
  if (finished_producers_.size() >= num_producers_) {
    return 0;
  }

  MessagePosition & pos = message_positions_.front();
  // check if message is too long
  if (pos.second > max_size) {
    LOG(ERROR) << "Message size exceeds limit, information lost.";
    retval = -1;
  } else {
    // read from buffer:
    // if this message stores in consecutive memory, just read
    // else, read from buffer tail then return to the head
    if (pos.first + pos.second <= queue_size_) {
      memcpy(dest, &queue_[pos.first], pos.second);
    } else {
      int64_t size_partial = queue_size_ - pos.first;
      memcpy(dest, &queue_[pos.first], size_partial);
      memcpy(&dest[size_partial], queue_, pos.second - size_partial);
    }
    retval = pos.second;
  }
  free_size_ += pos.second;
  message_positions_.pop();

  cond_not_full_.notify_one();

  return retval;
}

/*!
 * \brief Pop the oldest message from the queue into a string.
 * \param dest string that receives the message bytes (previous content is
 *        replaced)
 * \param is_blocking if true, wait for a message while producers are still
 *        active; if false, return immediately when the queue is empty
 * \return number of bytes stored in dest; 0 if the queue is empty
 *         (non-blocking mode, or all producers have finished)
 */
int64_t MessageQueue::Remove(string *dest, bool is_blocking) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (message_positions_.empty()) {
    if (!is_blocking) {
      return 0;
    }
    if (finished_producers_.size() >= num_producers_) {
      return 0;
    }
  }

  cond_not_empty_.wait(lock, [this] {
    return !message_positions_.empty() || exit_flag_.load();
  });
  // The wait can end because the exit flag was raised while the queue is
  // still empty. Calling front() on an empty queue is undefined behaviour,
  // so report "no message" instead.
  if (message_positions_.empty()) {
    return 0;
  }

  const MessagePosition & pos = message_positions_.front();
  // Read from the buffer: a message stored contiguously is copied in one
  // go, otherwise take the tail part first and then append the part that
  // wrapped around to the head of the buffer.
  if (pos.first + pos.second <= queue_size_) {
    dest->assign(&queue_[pos.first], pos.second);
  } else {
    const int64_t size_partial = queue_size_ - pos.first;
    dest->assign(&queue_[pos.first], size_partial);
    dest->append(queue_, pos.second - size_partial);
  }
  const int64_t retval = pos.second;
  free_size_ += pos.second;
  message_positions_.pop();

  // Wake one producer waiting for free space.
  cond_not_full_.notify_one();

  return retval;
}

/*!
 * \brief Mark a producer as finished.
 * \param producer_id id of the producer that will not add more messages
 *
 * Once the number of distinct finished ids reaches num_producers_, the exit
 * flag is raised and every consumer waiting for data is woken up.
 */
void MessageQueue::Signal(int producer_id) {
  std::lock_guard<std::mutex> guard(mutex_);
  finished_producers_.insert(producer_id);
  if (finished_producers_.size() < num_producers_) {
    // Some producers are still active; nothing to announce yet.
    return;
  }
  exit_flag_.store(true);
  cond_not_empty_.notify_all();
}

/*!
 * \brief Check whether the queue is drained and closed.
 * \return true if no message is pending and all producers have finished
 */
bool MessageQueue::EmptyAndNoMoreAdd() const {
  std::lock_guard<std::mutex> guard(mutex_);
  if (!message_positions_.empty()) {
    return false;
  }
  return finished_producers_.size() >= num_producers_;
}

}  // namespace network
}  // namespace dgl