// tree_learner.cpp
#include <LightGBM/tree_learner.h>

#include "serial_tree_learner.h"
#include "gpu_tree_learner.h"
#include "parallel_tree_learner.h"

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
9
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const Config* config) {
10
11
  if (device_type == std::string("cpu")) {
    if (learner_type == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
12
      return new SerialTreeLearner(config);
13
    } else if (learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
14
      return new FeatureParallelTreeLearner<SerialTreeLearner>(config);
15
    } else if (learner_type == std::string("data")) {
Guolin Ke's avatar
Guolin Ke committed
16
      return new DataParallelTreeLearner<SerialTreeLearner>(config);
17
    } else if (learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
18
      return new VotingParallelTreeLearner<SerialTreeLearner>(config);
19
    }
20
  } else if (device_type == std::string("gpu")) {
21
    if (learner_type == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
22
      return new GPUTreeLearner(config);
23
    } else if (learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
24
      return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
25
    } else if (learner_type == std::string("data")) {
Guolin Ke's avatar
Guolin Ke committed
26
      return new DataParallelTreeLearner<GPUTreeLearner>(config);
27
    } else if (learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
28
      return new VotingParallelTreeLearner<GPUTreeLearner>(config);
29
    }
Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
  }
  return nullptr;
}

}  // namespace LightGBM