/*!
 *  Copyright (c) 2019 by Contributors
 * \file msg_queue.cc
 * \brief Message queue for DGL distributed training.
 */
#include "msg_queue.h"

#include <dmlc/logging.h>

#include <cstring>

namespace dgl {
namespace network {

using std::string;

/*!
 * \brief Construct an empty, bounded message queue.
 * \param queue_size total capacity shared by all queued messages, in the same
 *        units as Message::size (must be >= 0)
 * \param num_producers number of producers that must call SignalFinished()
 *        before the queue is considered closed (must be >= 0)
 */
MessageQueue::MessageQueue(int64_t queue_size, int num_producers) {
  CHECK_GE(queue_size, 0);
  CHECK_GE(num_producers, 0);
  // Nothing is queued yet, so the whole capacity starts out free.
  num_producers_ = num_producers;
  queue_size_ = queue_size;
  free_size_ = queue_size;
}

25
STATUS MessageQueue::Add(Message msg, bool is_blocking) {
26
  // check if message is too long to fit into the queue
27
28
29
  if (msg.size > queue_size_) {
    LOG(WARNING) << "Message is larger than the queue.";
    return MSG_GT_SIZE;
30
  }
31
32
33
  if (msg.size <= 0) {
    LOG(WARNING) << "Message size (" << msg.size << ") is negative or zero.";
    return MSG_LE_ZERO;
34
35
36
  }
  std::unique_lock<std::mutex> lock(mutex_);
  if (finished_producers_.size() >= num_producers_) {
37
    return QUEUE_CLOSE;
38
  }
39
40
  if (msg.size > free_size_ && !is_blocking) {
    return QUEUE_FULL;
41
  }
42
  cond_not_full_.wait(lock, [&]() { return msg.size <= free_size_; });
43
44
45
  // Add data pointer to queue
  queue_.push(msg);
  free_size_ -= msg.size;
46
47
48
  // not empty signal
  cond_not_empty_.notify_one();

49
  return ADD_SUCCESS;
50
51
}

52
STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
53
  std::unique_lock<std::mutex> lock(mutex_);
54
  if (queue_.empty()) {
55
    if (!is_blocking) {
56
      return QUEUE_EMPTY;
57
58
    }
    if (finished_producers_.size() >= num_producers_) {
59
      return QUEUE_CLOSE;
60
61
62
    }
  }

63
64
  cond_not_empty_.wait(
      lock, [this] { return !queue_.empty() || exit_flag_.load(); });
65
66
  if (finished_producers_.size() >= num_producers_ && queue_.empty()) {
    return QUEUE_CLOSE;
67
68
  }

69
  Message old_msg = queue_.front();
70
71
72
  queue_.pop();
  msg->data = old_msg.data;
  msg->size = old_msg.size;
73
  msg->receiver_id = old_msg.receiver_id;
74
75
  msg->deallocator = old_msg.deallocator;
  free_size_ += old_msg.size;
76
77
  cond_not_full_.notify_one();

78
  return REMOVE_SUCCESS;
79
80
}

81
void MessageQueue::SignalFinished(int producer_id) {
82
83
84
85
86
87
88
89
90
91
  std::lock_guard<std::mutex> lock(mutex_);
  finished_producers_.insert(producer_id);
  // if all producers have finished, consumers should be
  // waken up to get this signal
  if (finished_producers_.size() >= num_producers_) {
    exit_flag_.store(true);
    cond_not_empty_.notify_all();
  }
}

92
93
94
95
96
bool MessageQueue::Empty() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return queue_.size() == 0;
}

97
98
bool MessageQueue::EmptyAndNoMoreAdd() const {
  std::lock_guard<std::mutex> lock(mutex_);
99
  return queue_.size() == 0 && finished_producers_.size() >= num_producers_;
100
101
102
103
}

}  // namespace network
}  // namespace dgl