lru_pool.h 3.52 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#ifndef LIGHTGBM_UTILS_LRU_POOL_H_
#define LIGHTGBM_UTILS_LRU_POOL_H_

#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/log.h>

#include <algorithm>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

namespace LightGBM {

/*!
* \brief An LRU-cached object pool, used to store historical histograms.
*
* The pool owns `cache_size` slots of T that are shared among `total_size`
* logical indices. When every index fits (cache_size >= total_size), index i
* addresses slot i directly. Otherwise indices are bound to slots on demand,
* and a lookup miss evicts the least-recently-used slot.
*
* The pool owns only the array of T. If T is a pointer type, the pointees are
* not freed here. Copying is disabled because the storage is owned exclusively.
*/
template<typename T>
class LRUPool {
public:
  /*!
  * \brief Constructor. The pool holds no storage until ResetSize() is called.
  */
  LRUPool() = default;

  /*!
  * \brief Destructor. All storage is released by the owning members.
  */
  ~LRUPool() = default;

  // Exclusive ownership of the slot array: copies would alias it.
  LRUPool(const LRUPool&) = delete;
  LRUPool& operator=(const LRUPool&) = delete;
  LRUPool(LRUPool&&) = default;
  LRUPool& operator=(LRUPool&&) = default;

  /*!
  * \brief Reset pool size, discarding all previous storage and mappings
  * \param cache_size Max cache size (number of slots); must be >= 2
  * \param total_size Total number of logical indices that will be used
  */
  void ResetSize(int cache_size, int total_size) {
    // free old memory
    FreeAll();
    cache_size_ = cache_size;
    // at least need 2 bucket to store smaller leaf and larger leaf
    CHECK(cache_size_ >= 2);
    total_size_ = total_size;
    // never allocate more slots than there are indices
    if (cache_size_ > total_size_) {
      cache_size_ = total_size_;
    }
    is_enough_ = (cache_size_ == total_size_);
    pool_.reset(new T[cache_size_]);
    if (!is_enough_) {
      // bookkeeping is only needed when indices must share slots
      mapper_.resize(static_cast<size_t>(total_size_));
      inverse_mapper_.resize(static_cast<size_t>(cache_size_));
      last_used_time_.resize(static_cast<size_t>(cache_size_));
      ResetMap();
    }
  }

  /*!
  * \brief Reset mapper: unbind every index from its slot and reset the clock.
  *        Slot contents are left untouched. No-op when every index has its
  *        own slot.
  */
  void ResetMap() {
    if (!is_enough_) {
      cur_time_ = 0;
      std::fill(mapper_.begin(), mapper_.end(), -1);
      std::fill(inverse_mapper_.begin(), inverse_mapper_.end(), -1);
      std::fill(last_used_time_.begin(), last_used_time_.end(), 0);
    }
  }

  /*!
  * \brief Set data for the pool for specific index
  * \param idx Used directly as a position in the slot array, so it should be
  *            in [0, cache_size). NOTE(review): presumably used by callers to
  *            initialize each slot; confirm against callers.
  * \param data Value copied into that slot
  */
  void Set(int idx, const T& data) {
    pool_[idx] = data;
  }

  /*!
  * \brief Get data for the specific index
  * \param idx Logical index to look up
  * \param out Receives the slot's contents. On a miss this is the evicted
  *            slot's previous (stale) contents, which the caller is expected
  *            to overwrite.
  * \return True if idx was already cached. False if a slot was evicted and
  *         bound to idx.
  */
  bool Get(int idx, T* out) {
    if (is_enough_) {
      *out = pool_[idx];
      return true;
    } else if (mapper_[idx] >= 0) {
      int slot = mapper_[idx];
      *out = pool_[slot];
      last_used_time_[slot] = ++cur_time_;
      return true;
    } else {
      // choose the least recently used slot
      int slot = static_cast<int>(ArrayArgs<int>::ArgMin(last_used_time_.data(), cache_size_));
      *out = pool_[slot];
      last_used_time_[slot] = ++cur_time_;

      // unbind the index that previously owned this slot
      if (inverse_mapper_[slot] >= 0) mapper_[inverse_mapper_[slot]] = -1;

      // bind idx to the slot
      mapper_[idx] = slot;
      inverse_mapper_[slot] = idx;
      return false;
    }
  }

  /*!
  * \brief Move data from one index to another index
  *
  * When every index has its own slot, the two slots' contents are swapped.
  * Otherwise the slot bound to src_idx is rebound to dst_idx and src_idx
  * becomes uncached. Any slot previously bound to dst_idx is released and
  * marked least-recently-used. Does nothing if src_idx is not cached.
  * \param src_idx Index whose data is moved
  * \param dst_idx Index that receives the data
  */
  void Move(int src_idx, int dst_idx) {
    if (is_enough_) {
      std::swap(pool_[src_idx], pool_[dst_idx]);
      return;
    }
    if (mapper_[src_idx] < 0) {
      return;
    }
    // get slot of src idx
    int slot = mapper_[src_idx];
    // reset src_idx
    mapper_[src_idx] = -1;

    // release the slot dst_idx used to own, otherwise its stale back-reference
    // would later unbind dst_idx from the slot moved in below
    int old_slot = mapper_[dst_idx];
    if (old_slot >= 0 && old_slot != slot) {
      inverse_mapper_[old_slot] = -1;
      // its contents are no longer reachable, so reuse it first
      last_used_time_[old_slot] = 0;
    }

    // move to dst idx
    mapper_[dst_idx] = slot;
    last_used_time_[slot] = ++cur_time_;
    inverse_mapper_[slot] = dst_idx;
  }
private:
  /*!
  * \brief Release the slot array and all bookkeeping, leaving empty members.
  */
  void FreeAll() {
    pool_.reset();
    // assigning an empty vector releases the old buffer
    mapper_ = std::vector<int>();
    inverse_mapper_ = std::vector<int>();
    last_used_time_ = std::vector<int>();
  }

  /*! \brief Slot storage, cache_size_ elements */
  std::unique_ptr<T[]> pool_;
  /*! \brief Number of slots actually allocated */
  int cache_size_ = 0;
  /*! \brief Number of logical indices */
  int total_size_ = 0;
  /*! \brief True when every index has its own slot (no eviction) */
  bool is_enough_ = false;
  /*! \brief index -> slot, -1 when the index is not cached */
  std::vector<int> mapper_;
  /*! \brief slot -> index, -1 when the slot is unbound */
  std::vector<int> inverse_mapper_;
  /*! \brief Per-slot logical timestamp of last use; smallest is evicted */
  std::vector<int> last_used_time_;
  /*! \brief Monotonic logical clock, reset by ResetMap() */
  int cur_time_ = 0;
};

}

#endif  // LIGHTGBM_UTILS_LRU_POOL_H_