tree_learner.cpp 1.51 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/tree_learner.h>

7
#include "gpu_tree_learner.h"
Guolin Ke's avatar
Guolin Ke committed
8
#include "parallel_tree_learner.h"
9
#include "serial_tree_learner.h"
Guolin Ke's avatar
Guolin Ke committed
10
11
12

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
13
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const Config* config) {
14
15
  if (device_type == std::string("cpu")) {
    if (learner_type == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
16
      return new SerialTreeLearner(config);
17
    } else if (learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
18
      return new FeatureParallelTreeLearner<SerialTreeLearner>(config);
19
    } else if (learner_type == std::string("data")) {
Guolin Ke's avatar
Guolin Ke committed
20
      return new DataParallelTreeLearner<SerialTreeLearner>(config);
21
    } else if (learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
22
      return new VotingParallelTreeLearner<SerialTreeLearner>(config);
23
    }
24
  } else if (device_type == std::string("gpu")) {
25
    if (learner_type == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
26
      return new GPUTreeLearner(config);
27
    } else if (learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
28
      return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
29
    } else if (learner_type == std::string("data")) {
Guolin Ke's avatar
Guolin Ke committed
30
      return new DataParallelTreeLearner<GPUTreeLearner>(config);
31
    } else if (learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
32
      return new VotingParallelTreeLearner<GPUTreeLearner>(config);
33
    }
Guolin Ke's avatar
Guolin Ke committed
34
35
36
37
38
  }
  return nullptr;
}

}  // namespace LightGBM