tree_learner.cpp 1.39 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
#include <LightGBM/tree_learner.h>

#include "serial_tree_learner.h"
4
#include "gpu_tree_learner.h"
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include "parallel_tree_learner.h"

namespace LightGBM {

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const TreeConfig* tree_config) {
  if (device_type == std::string("cpu")) {
    if (learner_type == std::string("serial")) {
      return new SerialTreeLearner(tree_config);
    } else if (learner_type == std::string("feature")) {
      return new FeatureParallelTreeLearner<SerialTreeLearner>(tree_config);
    } else if (learner_type == std::string("data")) {
      return new DataParallelTreeLearner<SerialTreeLearner>(tree_config);
    } else if (learner_type == std::string("voting")) {
      return new VotingParallelTreeLearner<SerialTreeLearner>(tree_config);
    }
  }
  else if (device_type == std::string("gpu")) {
    if (learner_type == std::string("serial")) {
      return new GPUTreeLearner(tree_config);
    } else if (learner_type == std::string("feature")) {
      return new FeatureParallelTreeLearner<GPUTreeLearner>(tree_config);
    } else if (learner_type == std::string("data")) {
      return new DataParallelTreeLearner<GPUTreeLearner>(tree_config);
    } else if (learner_type == std::string("voting")) {
      return new VotingParallelTreeLearner<GPUTreeLearner>(tree_config);
    }
Guolin Ke's avatar
Guolin Ke committed
31
32
33
34
35
  }
  return nullptr;
}

}  // namespace LightGBM