#ifndef LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
#define LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_

#include <LightGBM/utils/array_args.h>

#include <LightGBM/network.h>
#include "serial_tree_learner.h"

#include <cstring>

#include <vector>
#include <memory>

namespace LightGBM {

/*!
* \brief Feature parallel learning algorithm.
Qiwei Ye's avatar
Qiwei Ye committed
18
19
*        Different machine will find best split on different features, then sync global best split
*        It is recommonded used when #data is small or #feature is large
Guolin Ke's avatar
Guolin Ke committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
*/
class FeatureParallelTreeLearner: public SerialTreeLearner {
public:
  explicit FeatureParallelTreeLearner(const TreeConfig& tree_config);
  ~FeatureParallelTreeLearner();
  virtual void Init(const Dataset* train_data);

protected:
  void BeforeTrain() override;
  void FindBestSplitsForLeaves() override;
private:
  /*! \brief rank of local machine */
  int rank_;
  /*! \brief Number of machines of this parallel task */
  int num_machines_;
  /*! \brief Buffer for network send */
Guolin Ke's avatar
Guolin Ke committed
36
  std::vector<char> input_buffer_;
Guolin Ke's avatar
Guolin Ke committed
37
  /*! \brief Buffer for network receive */
Guolin Ke's avatar
Guolin Ke committed
38
  std::vector<char> output_buffer_;
Guolin Ke's avatar
Guolin Ke committed
39
40
41
42
};

/*!
* \brief Data parallel learning algorithm.
Qiwei Ye's avatar
Qiwei Ye committed
43
44
*        Workers use local data to construct histograms locally, then sync up global histograms.
*        It is recommonded used when #data is large or #feature is small
Guolin Ke's avatar
Guolin Ke committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
*/
class DataParallelTreeLearner: public SerialTreeLearner {
public:
  explicit DataParallelTreeLearner(const TreeConfig& tree_config);
  ~DataParallelTreeLearner();
  void Init(const Dataset* train_data) override;
protected:
  void BeforeTrain() override;
  void FindBestThresholds() override;
  void FindBestSplitsForLeaves() override;
  void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;

  inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
    if (leaf_idx >= 0) {
      return global_data_count_in_leaf_[leaf_idx];
    } else {
      return 0;
    }
  }

private:
  /*! \brief Rank of local machine */
  int rank_;
  /*! \brief Number of machines of this parallel task */
  int num_machines_;
  /*! \brief Buffer for network send */
Guolin Ke's avatar
Guolin Ke committed
71
  std::vector<char> input_buffer_;
Guolin Ke's avatar
Guolin Ke committed
72
  /*! \brief Buffer for network receive */
Guolin Ke's avatar
Guolin Ke committed
73
  std::vector<char> output_buffer_;
Guolin Ke's avatar
Guolin Ke committed
74
75
  /*! \brief different machines will aggregate histograms for different features,
       use this to mark local aggregate features*/
Guolin Ke's avatar
Guolin Ke committed
76
  std::vector<bool> is_feature_aggregated_;
Guolin Ke's avatar
Guolin Ke committed
77
  /*! \brief Block start index for reduce scatter */
Guolin Ke's avatar
Guolin Ke committed
78
  std::vector<int> block_start_;
Guolin Ke's avatar
Guolin Ke committed
79
  /*! \brief Block size for reduce scatter */
Guolin Ke's avatar
Guolin Ke committed
80
  std::vector<int> block_len_;
Hui Xue's avatar
Hui Xue committed
81
  /*! \brief Write positions for feature histograms */
Guolin Ke's avatar
Guolin Ke committed
82
  std::vector<int> buffer_write_start_pos_;
Hui Xue's avatar
Hui Xue committed
83
  /*! \brief Read positions for local feature histograms */
Guolin Ke's avatar
Guolin Ke committed
84
  std::vector<int> buffer_read_start_pos_;
Guolin Ke's avatar
Guolin Ke committed
85
86
87
  /*! \brief Size for reduce scatter */
  int reduce_scatter_size_;
  /*! \brief Store global number of data in leaves  */
Guolin Ke's avatar
Guolin Ke committed
88
  std::vector<data_size_t> global_data_count_in_leaf_;
Guolin Ke's avatar
Guolin Ke committed
89
90
91
92
};


}  // namespace LightGBM
#endif   // LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_