tree_learner.cpp 2.65 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/tree_learner.h>

7
#include "cuda_tree_learner.h"
8
#include "gpu_tree_learner.h"
9
#include "linear_tree_learner.h"
Guolin Ke's avatar
Guolin Ke committed
10
#include "parallel_tree_learner.h"
11
#include "serial_tree_learner.h"
12
#include "cuda/cuda_single_gpu_tree_learner.hpp"
Guolin Ke's avatar
Guolin Ke committed
13
14
15

namespace LightGBM {

16
17
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type,
                                            const Config* config) {
18
19
  if (device_type == std::string("cpu")) {
    if (learner_type == std::string("serial")) {
20
21
22
23
24
      if (config->linear_tree) {
        return new LinearTreeLearner(config);
      } else {
        return new SerialTreeLearner(config);
      }
25
    } else if (learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
26
      return new FeatureParallelTreeLearner<SerialTreeLearner>(config);
27
    } else if (learner_type == std::string("data")) {
Guolin Ke's avatar
Guolin Ke committed
28
      return new DataParallelTreeLearner<SerialTreeLearner>(config);
29
    } else if (learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
30
      return new VotingParallelTreeLearner<SerialTreeLearner>(config);
31
    }
32
  } else if (device_type == std::string("gpu")) {
33
    if (learner_type == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
34
      return new GPUTreeLearner(config);
35
    } else if (learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
36
      return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
37
    } else if (learner_type == std::string("data")) {
Guolin Ke's avatar
Guolin Ke committed
38
      return new DataParallelTreeLearner<GPUTreeLearner>(config);
39
    } else if (learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
40
      return new VotingParallelTreeLearner<GPUTreeLearner>(config);
41
    }
42
43
44
45
46
47
48
49
50
51
  } else if (device_type == std::string("cuda")) {
    if (learner_type == std::string("serial")) {
      return new CUDATreeLearner(config);
    } else if (learner_type == std::string("feature")) {
      return new FeatureParallelTreeLearner<CUDATreeLearner>(config);
    } else if (learner_type == std::string("data")) {
      return new DataParallelTreeLearner<CUDATreeLearner>(config);
    } else if (learner_type == std::string("voting")) {
      return new VotingParallelTreeLearner<CUDATreeLearner>(config);
    }
52
53
54
55
56
57
58
59
60
61
  } else if (device_type == std::string("cuda_exp")) {
    if (learner_type == std::string("serial")) {
      if (config->num_gpu == 1) {
        return new CUDASingleGPUTreeLearner(config);
      } else {
        Log::Fatal("cuda_exp only supports training on a single GPU.");
      }
    } else {
      Log::Fatal("cuda_exp only supports training on a single machine.");
    }
Guolin Ke's avatar
Guolin Ke committed
62
63
64
65
66
  }
  return nullptr;
}

}  // namespace LightGBM