tree_learner.cpp 1.35 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
#include <LightGBM/tree_learner.h>

#include "serial_tree_learner.h"
4
#include "gpu_tree_learner.h"
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include "parallel_tree_learner.h"

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
9
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const Config* config) {
10
11
  if (device_type == std::string("cpu")) {
    if (learner_type == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
12
      return new SerialTreeLearner(config);
13
    } else if (learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
14
      return new FeatureParallelTreeLearner<SerialTreeLearner>(config);
15
    } else if (learner_type == std::string("data")) {
Guolin Ke's avatar
Guolin Ke committed
16
      return new DataParallelTreeLearner<SerialTreeLearner>(config);
17
    } else if (learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
18
      return new VotingParallelTreeLearner<SerialTreeLearner>(config);
19
20
21
22
    }
  }
  else if (device_type == std::string("gpu")) {
    if (learner_type == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
23
      return new GPUTreeLearner(config);
24
    } else if (learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
25
      return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
26
    } else if (learner_type == std::string("data")) {
Guolin Ke's avatar
Guolin Ke committed
27
      return new DataParallelTreeLearner<GPUTreeLearner>(config);
28
    } else if (learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
29
      return new VotingParallelTreeLearner<GPUTreeLearner>(config);
30
    }
Guolin Ke's avatar
Guolin Ke committed
31
32
33
34
35
  }
  return nullptr;
}

}  // namespace LightGBM