tree_learner.cpp 2.33 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/tree_learner.h>

7
8
#include <string>

9
#include "gpu_tree_learner.h"
10
#include "linear_tree_learner.h"
Guolin Ke's avatar
Guolin Ke committed
11
#include "parallel_tree_learner.h"
12
#include "serial_tree_learner.h"
13
#include "cuda/cuda_single_gpu_tree_learner.hpp"
Guolin Ke's avatar
Guolin Ke committed
14
15
16

namespace LightGBM {

17
/*!
 * \brief Factory that builds the tree learner matching a device / learner-type pair.
 *
 * Dispatch table implemented below:
 *   - "cpu"  : "serial" (linear variant when config->linear_tree is set),
 *              "feature", "data", "voting" -- all based on SerialTreeLearner.
 *   - "gpu"  : same four learner types, based on GPUTreeLearner.
 *   - "cuda" : only "serial" with config->num_gpu == 1 yields a learner;
 *              any other request is reported through Log::Fatal.
 *
 * \param learner_type Learner kind: "serial", "feature", "data" or "voting".
 * \param device_type Device kind: "cpu", "gpu" or "cuda".
 * \param config Configuration forwarded to the learner constructor; it is only
 *               dereferenced on the "serial" paths (linear_tree / num_gpu).
 * \param boosting_on_cuda Forwarded to CUDASingleGPUTreeLearner; unused on other paths.
 * \return A learner allocated with `new`, or nullptr when the combination is not
 *         recognised (or if a Log::Fatal call on the cuda path returns).
 */
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type,
                                            const Config* config, const bool boosting_on_cuda) {
  // CPU learners.
  if (device_type == "cpu") {
    if (learner_type == "serial") {
      if (config->linear_tree) {
        return new LinearTreeLearner<SerialTreeLearner>(config);
      }
      return new SerialTreeLearner(config);
    }
    if (learner_type == "feature") {
      return new FeatureParallelTreeLearner<SerialTreeLearner>(config);
    }
    if (learner_type == "data") {
      return new DataParallelTreeLearner<SerialTreeLearner>(config);
    }
    if (learner_type == "voting") {
      return new VotingParallelTreeLearner<SerialTreeLearner>(config);
    }
    // Unknown learner type on a known device.
    return nullptr;
  }

  // GPU learners mirror the CPU table with GPUTreeLearner as the base.
  if (device_type == "gpu") {
    if (learner_type == "serial") {
      if (config->linear_tree) {
        return new LinearTreeLearner<GPUTreeLearner>(config);
      }
      return new GPUTreeLearner(config);
    }
    if (learner_type == "feature") {
      return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
    }
    if (learner_type == "data") {
      return new DataParallelTreeLearner<GPUTreeLearner>(config);
    }
    if (learner_type == "voting") {
      return new VotingParallelTreeLearner<GPUTreeLearner>(config);
    }
    // Unknown learner type on a known device.
    return nullptr;
  }

  // CUDA: reject unsupported setups first, then build the single-GPU learner.
  if (device_type == "cuda") {
    if (learner_type != "serial") {
      Log::Fatal("Currently cuda version only supports training on a single machine.");
    } else if (config->num_gpu != 1) {
      Log::Fatal("Currently cuda version only supports training on a single GPU.");
    } else {
      return new CUDASingleGPUTreeLearner(config, boosting_on_cuda);
    }
  }

  // Unrecognised device type, or a Log::Fatal call above returned.
  return nullptr;
}

}  // namespace LightGBM