"vscode:/vscode.git/clone" did not exist on "c836efcfdc9e331d6ae6dfa33cae9052803a02aa"
msg_queue.cc 2.75 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/*!
 *  Copyright (c) 2019 by Contributors
 * \file msg_queue.cc
 * \brief Message queue for DGL distributed training.
 */
#include <dmlc/logging.h>
#include <cstring>
#include <utility>

#include "msg_queue.h"

namespace dgl {
namespace network {

using std::string;

/*!
 * \brief Create a bounded message queue.
 *
 * \param queue_size Total capacity of the queue, measured in the same unit as
 *        Message::size. Must be >= 0. With 0, every Add() is rejected because
 *        messages must have a positive size that also fits the capacity.
 * \param num_producers Number of producers that must call SignalFinished()
 *        before the queue counts as closed. Must be >= 0. With 0, the
 *        queue is closed from the start and Add() returns QUEUE_CLOSE.
 */
MessageQueue::MessageQueue(int64_t queue_size, int num_producers) {
  CHECK_GE(queue_size, 0);
  CHECK_GE(num_producers, 0);
  queue_size_ = queue_size;
  // The queue starts empty, so all capacity is free.
  free_size_ = queue_size;
  num_producers_ = num_producers;
}

24
STATUS MessageQueue::Add(Message msg, bool is_blocking) {
25
  // check if message is too long to fit into the queue
26
27
28
  if (msg.size > queue_size_) {
    LOG(WARNING) << "Message is larger than the queue.";
    return MSG_GT_SIZE;
29
  }
30
31
32
  if (msg.size <= 0) {
    LOG(WARNING) << "Message size (" << msg.size << ") is negative or zero.";
    return MSG_LE_ZERO;
33
34
35
  }
  std::unique_lock<std::mutex> lock(mutex_);
  if (finished_producers_.size() >= num_producers_) {
36
37
    LOG(WARNING) << "Message queue is closed.";
    return QUEUE_CLOSE;
38
  }
39
40
  if (msg.size > free_size_ && !is_blocking) {
    return QUEUE_FULL;
41
42
  }
  cond_not_full_.wait(lock, [&]() {
43
    return msg.size <= free_size_;
44
  });
45
46
47
  // Add data pointer to queue
  queue_.push(msg);
  free_size_ -= msg.size;
48
49
50
  // not empty signal
  cond_not_empty_.notify_one();

51
  return ADD_SUCCESS;
52
53
}

54
STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
55
  std::unique_lock<std::mutex> lock(mutex_);
56
  if (queue_.empty()) {
57
    if (!is_blocking) {
58
      return QUEUE_EMPTY;
59
60
    }
    if (finished_producers_.size() >= num_producers_) {
61
62
      LOG(WARNING) << "Message queue is closed.";
      return QUEUE_CLOSE;
63
64
65
66
    }
  }

  cond_not_empty_.wait(lock, [this] {
67
    return !queue_.empty() || exit_flag_.load();
68
  });
69
70
71
  if (finished_producers_.size() >= num_producers_ && queue_.empty()) {
    LOG(WARNING) << "Message queue is closed.";
    return QUEUE_CLOSE;
72
73
  }

74
75
76
77
78
79
  Message & old_msg = queue_.front();
  queue_.pop();
  msg->data = old_msg.data;
  msg->size = old_msg.size;
  msg->deallocator = old_msg.deallocator;
  free_size_ += old_msg.size;
80
81
  cond_not_full_.notify_one();

82
  return REMOVE_SUCCESS;
83
84
}

85
void MessageQueue::SignalFinished(int producer_id) {
86
87
88
89
90
91
92
93
94
95
  std::lock_guard<std::mutex> lock(mutex_);
  finished_producers_.insert(producer_id);
  // if all producers have finished, consumers should be
  // waken up to get this signal
  if (finished_producers_.size() >= num_producers_) {
    exit_flag_.store(true);
    cond_not_empty_.notify_all();
  }
}

96
97
98
99
100
bool MessageQueue::Empty() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return queue_.size() == 0;
}

101
102
bool MessageQueue::EmptyAndNoMoreAdd() const {
  std::lock_guard<std::mutex> lock(mutex_);
103
  return queue_.size() == 0 &&
104
105
106
107
108
         finished_producers_.size() >= num_producers_;
}

}  // namespace network
}  // namespace dgl