msg_queue.cc 2.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/*!
 *  Copyright (c) 2019 by Contributors
 * \file msg_queue.cc
 * \brief Message queue for DGL distributed training.
 */
#include <dmlc/logging.h>
#include <cstring>

#include "msg_queue.h"

namespace dgl {
namespace network {

using std::string;

/*!
 * \brief Construct a message queue with a given capacity and producer count.
 * \param queue_size total capacity of the queue, in the same unit as
 *        Message::size; must be >= 0. The queue starts with all of it free.
 * \param num_producers number of producers the queue waits on before it is
 *        considered closed; must be >= 0.
 */
MessageQueue::MessageQueue(int64_t queue_size, int num_producers) {
  // Negative arguments are programming errors; fail the CHECK immediately.
  CHECK_GE(queue_size, 0);
  CHECK_GE(num_producers, 0);

  num_producers_ = num_producers;
  queue_size_ = queue_size;
  // Nothing has been enqueued yet, so the whole capacity is available.
  free_size_ = queue_size;
}

24
/*!
 * \brief Append a message to the queue.
 * \param msg message to enqueue; msg.size is charged against the free space.
 * \param is_blocking when true, wait until enough space is free; when false,
 *        return QUEUE_FULL right away if the message does not currently fit.
 * \return MSG_GT_SIZE if msg.size exceeds the queue capacity,
 *         MSG_LE_ZERO if msg.size is not positive,
 *         QUEUE_CLOSE if every producer has already signalled completion,
 *         QUEUE_FULL (non-blocking only) if there is not enough free space,
 *         ADD_SUCCESS otherwise.
 */
STATUS MessageQueue::Add(Message msg, bool is_blocking) {
  const auto msg_size = msg.size;

  // Size validation happens before the mutex is taken.
  if (msg_size > queue_size_) {
    LOG(WARNING) << "Message is larger than the queue.";
    return MSG_GT_SIZE;
  }
  if (msg_size <= 0) {
    LOG(WARNING) << "Message size (" << msg_size << ") is negative or zero.";
    return MSG_LE_ZERO;
  }

  std::unique_lock<std::mutex> guard(mutex_);

  // Once all producers have reported completion, nothing more is accepted.
  const bool closed = finished_producers_.size() >= num_producers_;
  if (closed) {
    return QUEUE_CLOSE;
  }

  // True when the message fits into the currently free space.
  const auto fits = [this, msg_size]() { return msg_size <= free_size_; };
  if (!fits()) {
    if (!is_blocking) {
      return QUEUE_FULL;
    }
    // Sleep until a consumer has released enough space.
    cond_not_full_.wait(guard, fits);
  }

  // Enqueue the message and account for the space it occupies.
  queue_.push(msg);
  free_size_ -= msg_size;

  // Wake one consumer that may be waiting for data.
  cond_not_empty_.notify_one();
  return ADD_SUCCESS;
}

53
/*!
 * \brief Take the message at the front of the queue.
 * \param msg output; receives the data, size, receiver_id and deallocator
 *        of the removed message.
 * \param is_blocking when true, wait for a message to arrive (or for all
 *        producers to finish); when false, return QUEUE_EMPTY right away
 *        if the queue holds nothing.
 * \return QUEUE_EMPTY (non-blocking only) if the queue is empty,
 *         QUEUE_CLOSE if the queue is empty and every producer has finished,
 *         REMOVE_SUCCESS otherwise.
 */
STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
  std::unique_lock<std::mutex> guard(mutex_);

  // True once every producer has called SignalFinished().
  const auto all_done = [this]() {
    return finished_producers_.size() >= num_producers_;
  };

  if (queue_.empty()) {
    if (!is_blocking) {
      return QUEUE_EMPTY;
    }
    if (all_done()) {
      return QUEUE_CLOSE;
    }
  }

  // Sleep until there is data, or until SignalFinished() raises the exit
  // flag; the loop re-checks the condition after every wake-up.
  while (queue_.empty() && !exit_flag_.load()) {
    cond_not_empty_.wait(guard);
  }

  // Woken because the producers are done and nothing is left to hand out.
  if (queue_.empty() && all_done()) {
    return QUEUE_CLOSE;
  }

  Message head = queue_.front();
  queue_.pop();

  // Hand the message fields over to the caller.
  msg->data = head.data;
  msg->size = head.size;
  msg->receiver_id = head.receiver_id;
  msg->deallocator = head.deallocator;

  // Give the space back and wake one producer that may be waiting for it.
  free_size_ += head.size;
  cond_not_full_.notify_one();
  return REMOVE_SUCCESS;
}

83
void MessageQueue::SignalFinished(int producer_id) {
84
85
86
87
88
89
90
91
92
93
  std::lock_guard<std::mutex> lock(mutex_);
  finished_producers_.insert(producer_id);
  // if all producers have finished, consumers should be
  // waken up to get this signal
  if (finished_producers_.size() >= num_producers_) {
    exit_flag_.store(true);
    cond_not_empty_.notify_all();
  }
}

94
95
96
97
98
/*!
 * \brief Check whether the queue currently holds no messages.
 * \return true if the queue is empty at the time of the call.
 */
bool MessageQueue::Empty() const {
  std::lock_guard<std::mutex> guard(mutex_);
  return queue_.empty();
}

99
100
bool MessageQueue::EmptyAndNoMoreAdd() const {
  std::lock_guard<std::mutex> lock(mutex_);
101
  return queue_.size() == 0 &&
102
103
104
105
106
         finished_producers_.size() >= num_producers_;
}

}  // namespace network
}  // namespace dgl