// parallel_tree_learner.h
#ifndef LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
#define LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_

#include <LightGBM/utils/array_args.h>

#include <LightGBM/network.h>
#include "serial_tree_learner.h"

#include <cstring>

#include <vector>

namespace LightGBM {

/*!
* \brief Feature parallel learning algorithm.
*        Different machines will find the best split on different features, then sync the global best split.
*        It is recommended when #data is small or #feature is large.
*/
class FeatureParallelTreeLearner: public SerialTreeLearner {
public:
  explicit FeatureParallelTreeLearner(const TreeConfig& tree_config);
  ~FeatureParallelTreeLearner();
  virtual void Init(const Dataset* train_data);

protected:
  void BeforeTrain() override;
  void FindBestSplitsForLeaves() override;
private:
  /*! \brief rank of local machine */
  int rank_;
  /*! \brief Number of machines of this parallel task */
  int num_machines_;
  /*! \brief Buffer for network send */
  char* input_buffer_;
  /*! \brief Buffer for network receive */
  char* output_buffer_;
};

/*!
* \brief Data parallel learning algorithm.
*        Workers use local data to construct histograms locally, then sync up global histograms.
*        It is recommended when #data is large or #feature is small.
*/
class DataParallelTreeLearner: public SerialTreeLearner {
public:
  /*! \brief Construct the learner from a tree configuration */
  explicit DataParallelTreeLearner(const TreeConfig& tree_config);
  ~DataParallelTreeLearner();
  /*!
  * \brief Overrides SerialTreeLearner::Init
  * \param train_data Training data passed to the learner
  */
  void Init(const Dataset* train_data) override;
protected:
  /*! \brief Overrides SerialTreeLearner::BeforeTrain */
  void BeforeTrain() override;
  /*! \brief Overrides SerialTreeLearner::FindBestThresholds */
  void FindBestThresholds() override;
  /*! \brief Overrides SerialTreeLearner::FindBestSplitsForLeaves */
  void FindBestSplitsForLeaves() override;
  /*! \brief Overrides SerialTreeLearner::Split */
  void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;

  /*!
  * \brief Get the global number of data in a leaf
  * \param leaf_idx Index of the leaf
  * \return global_data_count_in_leaf_[leaf_idx] when leaf_idx is non-negative, 0 otherwise.
  *         No upper bound check is performed on leaf_idx.
  */
  inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
    if (leaf_idx >= 0) {
      return global_data_count_in_leaf_[leaf_idx];
    } else {
      // A negative index yields a count of 0
      return 0;
    }
  }

private:
  // NOTE(review): all pointer members below are raw pointers; allocation and
  // release are not visible in this header -- presumably handled in the .cpp (confirm).
  /*! \brief Rank of local machine */
  int rank_;
  /*! \brief Number of machines of this parallel task */
  int num_machines_;
  /*! \brief Buffer for network send */
  char* input_buffer_;
  /*! \brief Buffer for network receive */
  char* output_buffer_;
  /*! \brief different machines will aggregate histograms for different features,
       use this to mark local aggregate features*/
  bool* is_feature_aggregated_;
  /*! \brief Block start index for reduce scatter */
  int* block_start_;
  /*! \brief Block size for reduce scatter */
  int* block_len_;
  /*! \brief Write positions for feature histograms */
  int* buffer_write_start_pos_;
  /*! \brief Read positions for local feature histograms */
  int* buffer_read_start_pos_;
  /*! \brief Size for reduce scatter */
  int reduce_scatter_size_;
  /*! \brief Store global number of data in leaves  */
  data_size_t* global_data_count_in_leaf_;
};


}  // namespace LightGBM
#endif   // LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_