data_partition.hpp 7.2 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#ifndef LIGHTGBM_TREELEARNER_DATA_PARTITION_HPP_
#define LIGHTGBM_TREELEARNER_DATA_PARTITION_HPP_

#include <LightGBM/meta.h>
#include <LightGBM/feature.h>

#include <omp.h>

#include <cstring>

#include <memory>
#include <vector>

namespace LightGBM {
/*!
* \brief DataPartition is used to store the the partition of data on tree.
*/
class DataPartition {
public:
  DataPartition(data_size_t num_data, int num_leafs)
    :num_data_(num_data), num_leaves_(num_leafs) {
    leaf_begin_ = new data_size_t[num_leaves_];
    leaf_count_ = new data_size_t[num_leaves_];
    indices_ = new data_size_t[num_data_];
    temp_left_indices_ = new data_size_t[num_data_];
    temp_right_indices_ = new data_size_t[num_data_];
    used_data_indices_ = nullptr;
#pragma omp parallel
#pragma omp master
    {
      num_threads_ = omp_get_num_threads();
    }
    offsets_buf_ = new data_size_t[num_threads_];
    left_cnts_buf_ = new data_size_t[num_threads_];
    right_cnts_buf_ = new data_size_t[num_threads_];
    left_write_pos_buf_ = new data_size_t[num_threads_];
    right_write_pos_buf_ = new data_size_t[num_threads_];
  }
  ~DataPartition() {
    delete[] leaf_begin_;
    delete[] leaf_count_;
    delete[] indices_;
    delete[] temp_left_indices_;
    delete[] temp_right_indices_;
    delete[] offsets_buf_;
    delete[] left_cnts_buf_;
    delete[] right_cnts_buf_;
    delete[] left_write_pos_buf_;
    delete[] right_write_pos_buf_;
  }

  /*!
  * \brief Init, will put all data on the root(leaf_idx = 0)
  */
  void Init() {
    for (int i = 0; i < num_leaves_; ++i) {
      leaf_count_[i] = 0;
    }
    leaf_begin_[0] = 0;
    if (used_data_indices_ == nullptr) {
      // if using all data
      leaf_count_[0] = num_data_;
#pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
        indices_[i] = i;
      }
    } else {
      // if bagging
      leaf_count_[0] = used_data_count_;
      std::memcpy(indices_, used_data_indices_, used_data_count_ * sizeof(data_size_t));
    }
  }

  /*!
  * \brief Get the data indices of one leaf
  * \param leaf index of leaf
  * \param indices output data indices
  * \return number of data on this leaf
  */
  data_size_t GetIndexOnLeaf(int leaf, data_size_t** indices) const {
    // copy reference, maybe unsafe, but faster
    data_size_t begin = leaf_begin_[leaf];
    (*indices) = static_cast<data_size_t*>(indices_ + begin);
    return leaf_count_[leaf];
  }

  /*!
  * \brief Split the data
  * \param leaf index of leaf
  * \param feature_bins feature bin data
  * \param threshold threshold that want to split
  * \param right_leaf index of right leaf
  */
  void Split(int leaf, const Bin* feature_bins, unsigned int threshold, int right_leaf) {
    const data_size_t min_inner_size = 1000;
    // get leaf boundary
    const data_size_t begin = leaf_begin_[leaf];
    const data_size_t cnt = leaf_count_[leaf];

    data_size_t inner_size = (cnt + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
    // split data multi-threading
#pragma omp parallel for schedule(static, 1)
    for (int i = 0; i < num_threads_; ++i) {
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > cnt) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > cnt) { cur_cnt = cnt - cur_start; }
      // split data inner, reduce the times of function called
      data_size_t cur_left_count = feature_bins->Split(threshold, indices_ + begin + cur_start, cur_cnt,
        temp_left_indices_ + cur_start, temp_right_indices_ + cur_start);
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
    }
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];
    // copy back indices of right leaf to indices_
#pragma omp parallel for schedule(static, 1)
    for (int i = 0; i < num_threads_; ++i) {
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(indices_ + begin + left_write_pos_buf_[i], temp_left_indices_ + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
      }
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(indices_ + begin + left_cnt + right_write_pos_buf_[i], temp_right_indices_ + offsets_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
      }
    }
    // update leaf boundary
    leaf_count_[leaf] = left_cnt;
    leaf_begin_[right_leaf] = left_cnt + begin;
    leaf_count_[right_leaf] = cnt - left_cnt;
  }

  /*!
  * \brief SetLabelAt used data indices before training, used for bagging
  * \param used_data_indices indices of used data
  * \param num_used_data number of used data
  */
  void SetUsedDataIndices(const data_size_t * used_data_indices, data_size_t num_used_data) {
    used_data_indices_ = used_data_indices;
    used_data_count_ = num_used_data;
  }

  /*!
  * \brief Get number of data on one leaf
  * \param leaf index of leaf
  * \return number of data of this leaf
  */
  data_size_t leaf_count(int leaf) const { return leaf_count_[leaf]; }

  /*!
  * \brief Get leaf begin
  * \param leaf index of leaf
  * \return begin index of this leaf
  */
  data_size_t leaf_begin(int leaf) const { return leaf_begin_[leaf]; }

  const data_size_t* indices() const { return indices_; }

  /*! \brief Get number of leaves */
  int num_leaves() const { return num_leaves_; }

private:
  /*! \brief Number of all data */
  data_size_t num_data_;
  /*! \brief Number of all leaves */
  int num_leaves_;
  /*! \brief start index of data on one leaf */
  data_size_t* leaf_begin_;
  /*! \brief number of data on one leaf */
  data_size_t* leaf_count_;
  /*! \brief Store all data's indices, order by leaf[data_in_leaf0,..,data_leaf1,..] */
  data_size_t* indices_;
  /*! \brief team indices buffer for split */
  data_size_t* temp_left_indices_;
  /*! \brief team indices buffer for split */
  data_size_t* temp_right_indices_;
  /*! \brief used data indices, used for bagging */
  const data_size_t* used_data_indices_;
  /*! \brief used data count, used for bagging */
  data_size_t used_data_count_;
  /*! \brief number of threads */
  int num_threads_;
  /*! \brief Buffer for multi-threading data partition, used to store offset for different threads */
  data_size_t* offsets_buf_;
  /*! \brief Buffer for multi-threading data partition, used to store left count after split for different threads */
  data_size_t* left_cnts_buf_;
  /*! \brief Buffer for multi-threading data partition, used to store right count after split for different threads */
  data_size_t* right_cnts_buf_;
  /*! \brief Buffer for multi-threading data partition, used to store write position of left leaf for different threads */
  data_size_t* left_write_pos_buf_;
  /*! \brief Buffer for multi-threading data partition, used to store write position of right leaf for different threads */
  data_size_t* right_write_pos_buf_;
};

}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
204
#endif   // LightGBM_TREELEARNER_DATA_PARTITION_HPP_