/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
#define LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_

#include <cstring>
#include <memory>
#include <vector>

#include <LightGBM/network.h>
#include <LightGBM/utils/array_args.h>

#include "gpu_tree_learner.h"
#include "serial_tree_learner.h"

namespace LightGBM {

/*!
* \brief Feature parallel learning algorithm.
*        Different machines will find the best split on different features, then sync the global best split.
*        It is recommended when #data is small or #feature is large.
*/
template <typename TREELEARNER_T>
class FeatureParallelTreeLearner: public TREELEARNER_T {
Nikita Titov's avatar
Nikita Titov committed
27
 public:
Guolin Ke's avatar
Guolin Ke committed
28
  explicit FeatureParallelTreeLearner(const Config* config);
Guolin Ke's avatar
Guolin Ke committed
29
  ~FeatureParallelTreeLearner();
30
  void Init(const Dataset* train_data, bool is_constant_hessian) override;
Guolin Ke's avatar
Guolin Ke committed
31

Nikita Titov's avatar
Nikita Titov committed
32
 protected:
Guolin Ke's avatar
Guolin Ke committed
33
  void BeforeTrain() override;
Guolin Ke's avatar
Guolin Ke committed
34
  void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
Nikita Titov's avatar
Nikita Titov committed
35
36

 private:
Guolin Ke's avatar
Guolin Ke committed
37
38
39
40
41
  /*! \brief rank of local machine */
  int rank_;
  /*! \brief Number of machines of this parallel task */
  int num_machines_;
  /*! \brief Buffer for network send */
Guolin Ke's avatar
Guolin Ke committed
42
  std::vector<char> input_buffer_;
Guolin Ke's avatar
Guolin Ke committed
43
  /*! \brief Buffer for network receive */
Guolin Ke's avatar
Guolin Ke committed
44
  std::vector<char> output_buffer_;
Guolin Ke's avatar
Guolin Ke committed
45
46
47
48
};

/*!
* \brief Data parallel learning algorithm.
*        Workers use local data to construct histograms locally, then sync up global histograms.
*        It is recommended when #data is large or #feature is small.
*/
template <typename TREELEARNER_T>
class DataParallelTreeLearner: public TREELEARNER_T {
Nikita Titov's avatar
Nikita Titov committed
54
 public:
Guolin Ke's avatar
Guolin Ke committed
55
  explicit DataParallelTreeLearner(const Config* config);
Guolin Ke's avatar
Guolin Ke committed
56
  ~DataParallelTreeLearner();
57
  void Init(const Dataset* train_data, bool is_constant_hessian) override;
Guolin Ke's avatar
Guolin Ke committed
58
  void ResetConfig(const Config* config) override;
59

Nikita Titov's avatar
Nikita Titov committed
60
 protected:
Guolin Ke's avatar
Guolin Ke committed
61
  void BeforeTrain() override;
Guolin Ke's avatar
Guolin Ke committed
62
63
  void FindBestSplits() override;
  void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
Guolin Ke's avatar
Guolin Ke committed
64
65
66
67
68
69
70
71
72
73
  void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;

  inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
    if (leaf_idx >= 0) {
      return global_data_count_in_leaf_[leaf_idx];
    } else {
      return 0;
    }
  }

Nikita Titov's avatar
Nikita Titov committed
74
 private:
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
  /*! \brief Rank of local machine */
  int rank_;
  /*! \brief Number of machines of this parallel task */
  int num_machines_;
  /*! \brief Buffer for network send */
Guolin Ke's avatar
Guolin Ke committed
80
  std::vector<char> input_buffer_;
Guolin Ke's avatar
Guolin Ke committed
81
  /*! \brief Buffer for network receive */
Guolin Ke's avatar
Guolin Ke committed
82
  std::vector<char> output_buffer_;
Guolin Ke's avatar
Guolin Ke committed
83
84
  /*! \brief different machines will aggregate histograms for different features,
       use this to mark local aggregate features*/
Guolin Ke's avatar
Guolin Ke committed
85
  std::vector<bool> is_feature_aggregated_;
Guolin Ke's avatar
Guolin Ke committed
86
  /*! \brief Block start index for reduce scatter */
Guolin Ke's avatar
Guolin Ke committed
87
  std::vector<comm_size_t> block_start_;
Guolin Ke's avatar
Guolin Ke committed
88
  /*! \brief Block size for reduce scatter */
Guolin Ke's avatar
Guolin Ke committed
89
  std::vector<comm_size_t> block_len_;
Hui Xue's avatar
Hui Xue committed
90
  /*! \brief Write positions for feature histograms */
Guolin Ke's avatar
Guolin Ke committed
91
  std::vector<comm_size_t> buffer_write_start_pos_;
Hui Xue's avatar
Hui Xue committed
92
  /*! \brief Read positions for local feature histograms */
Guolin Ke's avatar
Guolin Ke committed
93
  std::vector<comm_size_t> buffer_read_start_pos_;
Guolin Ke's avatar
Guolin Ke committed
94
  /*! \brief Size for reduce scatter */
Guolin Ke's avatar
Guolin Ke committed
95
  comm_size_t reduce_scatter_size_;
Guolin Ke's avatar
Guolin Ke committed
96
  /*! \brief Store global number of data in leaves  */
Guolin Ke's avatar
Guolin Ke committed
97
  std::vector<data_size_t> global_data_count_in_leaf_;
Guolin Ke's avatar
Guolin Ke committed
98
99
};

/*!
* \brief Voting based data parallel learning algorithm.
* Like data parallel, but does not aggregate histograms for all features.
* Here voting is used to select features, and histograms are aggregated only for the selected features.
* When #data is large and #feature is large, you can use this to get a better speed-up.
*/
template <typename TREELEARNER_T>
class VotingParallelTreeLearner: public TREELEARNER_T {
Nikita Titov's avatar
Nikita Titov committed
108
 public:
Guolin Ke's avatar
Guolin Ke committed
109
  explicit VotingParallelTreeLearner(const Config* config);
Guolin Ke's avatar
Guolin Ke committed
110
  ~VotingParallelTreeLearner() { }
111
  void Init(const Dataset* train_data, bool is_constant_hessian) override;
Guolin Ke's avatar
Guolin Ke committed
112
  void ResetConfig(const Config* config) override;
113

Nikita Titov's avatar
Nikita Titov committed
114
 protected:
Guolin Ke's avatar
Guolin Ke committed
115
  void BeforeTrain() override;
Guolin Ke's avatar
Guolin Ke committed
116
  bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
Guolin Ke's avatar
Guolin Ke committed
117
118
  void FindBestSplits() override;
  void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
Guolin Ke's avatar
Guolin Ke committed
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
  void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;

  inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
    if (leaf_idx >= 0) {
      return global_data_count_in_leaf_[leaf_idx];
    } else {
      return 0;
    }
  }
  /*!
  * \brief Perform global voting
  * \param leaf_idx index of leaf
  * \param splits All splits from local voting
  * \param out Result of gobal voting, only store feature indices
  */
134
  void GlobalVoting(int leaf_idx, const std::vector<LightSplitInfo>& splits,
Guolin Ke's avatar
Guolin Ke committed
135
136
137
138
139
140
141
142
143
    std::vector<int>* out);
  /*!
  * \brief Copy local histgram to buffer
  * \param smaller_top_features Selected features for smaller leaf
  * \param larger_top_features Selected features for larger leaf
  */
  void CopyLocalHistogram(const std::vector<int>& smaller_top_features,
    const std::vector<int>& larger_top_features);

Nikita Titov's avatar
Nikita Titov committed
144
 private:
Guolin Ke's avatar
Guolin Ke committed
145
  /*! \brief Tree config used in local mode */
Guolin Ke's avatar
Guolin Ke committed
146
  Config local_config_;
Guolin Ke's avatar
Guolin Ke committed
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
  /*! \brief Voting size */
  int top_k_;
  /*! \brief Rank of local machine*/
  int rank_;
  /*! \brief Number of machines */
  int num_machines_;
  /*! \brief Buffer for network send */
  std::vector<char> input_buffer_;
  /*! \brief Buffer for network receive */
  std::vector<char> output_buffer_;
  /*! \brief different machines will aggregate histograms for different features,
  use this to mark local aggregate features*/
  std::vector<bool> smaller_is_feature_aggregated_;
  /*! \brief different machines will aggregate histograms for different features,
  use this to mark local aggregate features*/
  std::vector<bool> larger_is_feature_aggregated_;
  /*! \brief Block start index for reduce scatter */
Guolin Ke's avatar
Guolin Ke committed
164
  std::vector<comm_size_t> block_start_;
Guolin Ke's avatar
Guolin Ke committed
165
  /*! \brief Block size for reduce scatter */
Guolin Ke's avatar
Guolin Ke committed
166
  std::vector<comm_size_t> block_len_;
Guolin Ke's avatar
Guolin Ke committed
167
  /*! \brief Read positions for feature histgrams at smaller leaf */
Guolin Ke's avatar
Guolin Ke committed
168
  std::vector<comm_size_t> smaller_buffer_read_start_pos_;
Guolin Ke's avatar
Guolin Ke committed
169
  /*! \brief Read positions for feature histgrams at larger leaf */
Guolin Ke's avatar
Guolin Ke committed
170
  std::vector<comm_size_t> larger_buffer_read_start_pos_;
Guolin Ke's avatar
Guolin Ke committed
171
  /*! \brief Size for reduce scatter */
Guolin Ke's avatar
Guolin Ke committed
172
  comm_size_t reduce_scatter_size_;
Guolin Ke's avatar
Guolin Ke committed
173
174
175
176
177
178
179
180
181
182
  /*! \brief Store global number of data in leaves  */
  std::vector<data_size_t> global_data_count_in_leaf_;
  /*! \brief Store global split information for smaller leaf  */
  std::unique_ptr<LeafSplits> smaller_leaf_splits_global_;
  /*! \brief Store global split information for larger leaf  */
  std::unique_ptr<LeafSplits> larger_leaf_splits_global_;
  /*! \brief Store global histogram for smaller leaf  */
  std::unique_ptr<FeatureHistogram[]> smaller_leaf_histogram_array_global_;
  /*! \brief Store global histogram for larger leaf  */
  std::unique_ptr<FeatureHistogram[]> larger_leaf_histogram_array_global_;
Guolin Ke's avatar
Guolin Ke committed
183

184
185
  std::vector<hist_t> smaller_leaf_histogram_data_;
  std::vector<hist_t> larger_leaf_histogram_data_;
Guolin Ke's avatar
Guolin Ke committed
186
  std::vector<FeatureMetainfo> feature_metas_;
Guolin Ke's avatar
Guolin Ke committed
187
};
// To-do: reduce the communication cost by using bitset to communicate.
inline void SyncUpGlobalBestSplit(char* input_buffer_, char* output_buffer_, SplitInfo* smaller_best_split, SplitInfo* larger_best_split, int max_cat_threshold) {
191
  // sync global best info
192
  int size = SplitInfo::Size(max_cat_threshold);
193
194
  smaller_best_split->CopyTo(input_buffer_);
  larger_best_split->CopyTo(input_buffer_ + size);
195
  Network::Allreduce(input_buffer_, size * 2, size, output_buffer_,
Guolin Ke's avatar
Guolin Ke committed
196
197
                     [] (const char* src, char* dst, int size, comm_size_t len) {
    comm_size_t used_size = 0;
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
    LightSplitInfo p1, p2;
    while (used_size < len) {
      p1.CopyFrom(src);
      p2.CopyFrom(dst);
      if (p1 > p2) {
        std::memcpy(dst, src, size);
      }
      src += size;
      dst += size;
      used_size += size;
    }
  });
  // copy back
  smaller_best_split->CopyFrom(output_buffer_);
  larger_best_split->CopyFrom(output_buffer_ + size);
}

}  // namespace LightGBM
#endif   // LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_