#include #include "serial_tree_learner.h" #include "gpu_tree_learner.h" #include "parallel_tree_learner.h" namespace LightGBM { TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const Config* config) { if (device_type == std::string("cpu")) { if (learner_type == std::string("serial")) { return new SerialTreeLearner(config); } else if (learner_type == std::string("feature")) { return new FeatureParallelTreeLearner(config); } else if (learner_type == std::string("data")) { return new DataParallelTreeLearner(config); } else if (learner_type == std::string("voting")) { return new VotingParallelTreeLearner(config); } } else if (device_type == std::string("gpu")) { if (learner_type == std::string("serial")) { return new GPUTreeLearner(config); } else if (learner_type == std::string("feature")) { return new FeatureParallelTreeLearner(config); } else if (learner_type == std::string("data")) { return new DataParallelTreeLearner(config); } else if (learner_type == std::string("voting")) { return new VotingParallelTreeLearner(config); } } return nullptr; } } // namespace LightGBM