parallel_tree_learner.h 2.85 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#ifndef LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
#define LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_

#include <LightGBM/utils/array_args.h>

#include <LightGBM/network.h>
#include "serial_tree_learner.h"

#include <cstring>

#include <vector>

namespace LightGBM {

/*!
* \brief Feature parallel learning algorithm.
* Different machine will find best split on different features, then sync global best split
* When #data is small or #feature is large, you can use this to have better speed-up
*/
class FeatureParallelTreeLearner: public SerialTreeLearner {
public:
  explicit FeatureParallelTreeLearner(const TreeConfig& tree_config);
  ~FeatureParallelTreeLearner();
  virtual void Init(const Dataset* train_data);

protected:
  void BeforeTrain() override;
  void FindBestSplitsForLeaves() override;
private:
  /*! \brief rank of local machine */
  int rank_;
  /*! \brief Number of machines of this parallel task */
  int num_machines_;
  /*! \brief Buffer for network send */
  char* input_buffer_;
  /*! \brief Buffer for network receive */
  char* output_buffer_;
};

/*!
* \brief Data parallel learning algorithm.
* Workers use local data to construct histograms locally, then sync up global histograms.
* When #data is large or #feature is small, use this for better speed-up.
*/
class DataParallelTreeLearner: public SerialTreeLearner {
public:
  explicit DataParallelTreeLearner(const TreeConfig& tree_config);
  ~DataParallelTreeLearner();
  // NOTE(review): the raw arrays below are assumed to be owned by this object
  // and released by the out-of-line destructor (confirm in the .cpp). A
  // member-wise copy would then release them twice, so copying is disabled.
  DataParallelTreeLearner(const DataParallelTreeLearner&) = delete;
  DataParallelTreeLearner& operator=(const DataParallelTreeLearner&) = delete;
  void Init(const Dataset* train_data) override;

protected:
  void BeforeTrain() override;
  void FindBestThresholds() override;
  void FindBestSplitsForLeaves() override;
  void Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf) override;

  /*!
  * \brief Number of data points in a leaf, read from the globally synced counts.
  * \param leaf_idx Index of the leaf; a negative index yields 0.
  * \return global_data_count_in_leaf_[leaf_idx], or 0 when leaf_idx < 0.
  */
  data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
    if (leaf_idx >= 0) {
      return global_data_count_in_leaf_[leaf_idx];
    } else {
      return 0;
    }
  }

private:
  /*! \brief Rank of local machine */
  int rank_ = 0;
  /*! \brief Number of machines of this parallel task */
  int num_machines_ = 0;
  /*! \brief Buffer for network send */
  char* input_buffer_ = nullptr;
  /*! \brief Buffer for network receive */
  char* output_buffer_ = nullptr;
  /*! \brief Different machines aggregate histograms for different features;
       marks the features whose histograms are aggregated locally */
  bool* is_feature_aggregated_ = nullptr;
  /*! \brief Block start index for reduce scatter */
  int* block_start_ = nullptr;
  /*! \brief Block size for reduce scatter */
  int* block_len_ = nullptr;
  /*! \brief Write positions for feature histograms */
  int* buffer_write_start_pos_ = nullptr;
  /*! \brief Read positions for local feature histograms */
  int* buffer_read_start_pos_ = nullptr;
  /*! \brief Size for reduce scatter */
  int reduce_scatter_size_ = 0;
  /*! \brief Global number of data in each leaf */
  data_size_t* global_data_count_in_leaf_ = nullptr;
};


}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
92
#endif   // LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
Guolin Ke's avatar
Guolin Ke committed
93