Unverified Commit c7d3ac1b authored by dragonbra's avatar dragonbra Committed by GitHub
Browse files

[GPU] Add support for linear tree with device=gpu (#6567)



* basic gpu_linear_tree_learner implementation

* corresponding config of gpu linear tree

* Update src/io/config.cpp
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* work around for gpu linear tree learner without gpu enabled

* add #endif

* add #ifdef USE_GPU

* fix lint problems

* fix compilation when USE_GPU is OFF

* add destructor

* add gpu_linear_tree_learner.cpp in make file list

* use template for linear tree learner

---------
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarshiyu1994 <shiyu_k1994@qq.com>
parent 283cdde4
...@@ -417,9 +417,9 @@ void Config::CheckParamConflict(const std::unordered_map<std::string, std::strin ...@@ -417,9 +417,9 @@ void Config::CheckParamConflict(const std::unordered_map<std::string, std::strin
} }
// linear tree learner must be serial type and run on CPU device // linear tree learner must be serial type and run on CPU device
if (linear_tree) { if (linear_tree) {
if (device_type != std::string("cpu")) { if (device_type != std::string("cpu") && device_type != std::string("gpu")) {
device_type = "cpu"; device_type = "cpu";
Log::Warning("Linear tree learner only works with CPU."); Log::Warning("Linear tree learner only works with CPU and GPU. Falling back to CPU now.");
} }
if (tree_learner != std::string("serial")) { if (tree_learner != std::string("serial")) {
tree_learner = "serial"; tree_learner = "serial";
......
...@@ -10,20 +10,22 @@ ...@@ -10,20 +10,22 @@
namespace LightGBM { namespace LightGBM {
void LinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) { template <typename TREE_LEARNER_TYPE>
SerialTreeLearner::Init(train_data, is_constant_hessian); void LinearTreeLearner<TREE_LEARNER_TYPE>::Init(const Dataset* train_data, bool is_constant_hessian) {
LinearTreeLearner::InitLinear(train_data, config_->num_leaves); TREE_LEARNER_TYPE::Init(train_data, is_constant_hessian);
LinearTreeLearner::InitLinear(train_data, this->config_->num_leaves);
} }
void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves) { template <typename TREE_LEARNER_TYPE>
void LinearTreeLearner<TREE_LEARNER_TYPE>::InitLinear(const Dataset* train_data, const int max_leaves) {
leaf_map_ = std::vector<int>(train_data->num_data(), -1); leaf_map_ = std::vector<int>(train_data->num_data(), -1);
contains_nan_ = std::vector<int8_t>(train_data->num_features(), 0); contains_nan_ = std::vector<int8_t>(train_data->num_features(), 0);
// identify features containing nans // identify features containing nans
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int feat = 0; feat < train_data->num_features(); ++feat) { for (int feat = 0; feat < train_data->num_features(); ++feat) {
auto bin_mapper = train_data_->FeatureBinMapper(feat); auto bin_mapper = this->train_data_->FeatureBinMapper(feat);
if (bin_mapper->bin_type() == BinType::NumericalBin) { if (bin_mapper->bin_type() == BinType::NumericalBin) {
const float* feat_ptr = train_data_->raw_index(feat); const float* feat_ptr = this->train_data_->raw_index(feat);
for (int i = 0; i < train_data->num_data(); ++i) { for (int i = 0; i < train_data->num_data(); ++i) {
if (std::isnan(feat_ptr[i])) { if (std::isnan(feat_ptr[i])) {
contains_nan_[feat] = 1; contains_nan_[feat] = 1;
...@@ -40,7 +42,7 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav ...@@ -40,7 +42,7 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav
} }
} }
// preallocate the matrix used to calculate linear model coefficients // preallocate the matrix used to calculate linear model coefficients
int max_num_feat = std::min(max_leaves, train_data_->num_numeric_features()); int max_num_feat = std::min(max_leaves, this->train_data_->num_numeric_features());
XTHX_.clear(); XTHX_.clear();
XTg_.clear(); XTg_.clear();
for (int i = 0; i < max_leaves; ++i) { for (int i = 0; i < max_leaves; ++i) {
...@@ -59,25 +61,26 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav ...@@ -59,25 +61,26 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav
} }
} }
Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) { template <typename TREE_LEARNER_TYPE>
Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer); Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer);
gradients_ = gradients; this->gradients_ = gradients;
hessians_ = hessians; this->hessians_ = hessians;
int num_threads = OMP_NUM_THREADS(); int num_threads = OMP_NUM_THREADS();
if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) { if (this->share_state_->num_threads != num_threads && this->share_state_->num_threads > 0) {
Log::Warning( Log::Warning(
"Detected that num_threads changed during training (from %d to %d), " "Detected that num_threads changed during training (from %d to %d), "
"it may cause unexpected errors.", "it may cause unexpected errors.",
share_state_->num_threads, num_threads); this->share_state_->num_threads, num_threads);
} }
share_state_->num_threads = num_threads; this->share_state_->num_threads = num_threads;
// some initial works before training // some initial works before training
BeforeTrain(); this->BeforeTrain();
auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, true, true)); auto tree = std::unique_ptr<Tree>(new Tree(this->config_->num_leaves, true, true));
auto tree_ptr = tree.get(); auto tree_ptr = tree.get();
constraints_->ShareTreePointer(tree_ptr); this->constraints_->ShareTreePointer(tree_ptr);
// root leaf // root leaf
int left_leaf = 0; int left_leaf = 0;
...@@ -85,25 +88,25 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -85,25 +88,25 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians
// only root leaf can be splitted on first time // only root leaf can be splitted on first time
int right_leaf = -1; int right_leaf = -1;
int init_splits = ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth); int init_splits = this->ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth);
for (int split = init_splits; split < config_->num_leaves - 1; ++split) { for (int split = init_splits; split < this->config_->num_leaves - 1; ++split) {
// some initial works before finding best split // some initial works before finding best split
if (BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) { if (this->BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) {
// find best threshold for every feature // find best threshold for every feature
FindBestSplits(tree_ptr); this->FindBestSplits(tree_ptr);
} }
// Get a leaf with max split gain // Get a leaf with max split gain
int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_)); int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(this->best_split_per_leaf_));
// Get split information for best leaf // Get split information for best leaf
const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf]; const SplitInfo& best_leaf_SplitInfo = this->best_split_per_leaf_[best_leaf];
// cannot split, quit // cannot split, quit
if (best_leaf_SplitInfo.gain <= 0.0) { if (best_leaf_SplitInfo.gain <= 0.0) {
Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain); Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain);
break; break;
} }
// split tree with best leaf // split tree with best leaf
Split(tree_ptr, best_leaf, &left_leaf, &right_leaf); this->Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf)); cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
} }
...@@ -120,21 +123,22 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -120,21 +123,22 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians
GetLeafMap(tree_ptr); GetLeafMap(tree_ptr);
if (has_nan) { if (has_nan) {
CalculateLinear<true>(tree_ptr, false, gradients_, hessians_, is_first_tree); CalculateLinear<true>(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree);
} else { } else {
CalculateLinear<false>(tree_ptr, false, gradients_, hessians_, is_first_tree); CalculateLinear<false>(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree);
} }
Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth); Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth);
return tree.release(); return tree.release();
} }
Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const { template <typename TREE_LEARNER_TYPE>
auto tree = SerialTreeLearner::FitByExistingTree(old_tree, gradients, hessians); Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
auto tree = TREE_LEARNER_TYPE::FitByExistingTree(old_tree, gradients, hessians);
bool has_nan = false; bool has_nan = false;
if (any_nan_) { if (any_nan_) {
for (int i = 0; i < tree->num_leaves() - 1 ; ++i) { for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) { if (contains_nan_[this->train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
has_nan = true; has_nan = true;
break; break;
} }
...@@ -149,28 +153,31 @@ Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* ...@@ -149,28 +153,31 @@ Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
return tree; return tree;
} }
Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred, template <typename TREE_LEARNER_TYPE>
Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
const score_t* gradients, const score_t *hessians) const { const score_t* gradients, const score_t *hessians) const {
data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves()); this->data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
return LinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians); return LinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
} }
void LinearTreeLearner::GetLeafMap(Tree* tree) const { template <typename TREE_LEARNER_TYPE>
void LinearTreeLearner<TREE_LEARNER_TYPE>::GetLeafMap(Tree* tree) const {
std::fill(leaf_map_.begin(), leaf_map_.end(), -1); std::fill(leaf_map_.begin(), leaf_map_.end(), -1);
// map data to leaf number // map data to leaf number
const data_size_t* ind = data_partition_->indices(); const data_size_t* ind = this->data_partition_->indices();
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(dynamic) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(dynamic)
for (int i = 0; i < tree->num_leaves(); ++i) { for (int i = 0; i < tree->num_leaves(); ++i) {
data_size_t idx = data_partition_->leaf_begin(i); data_size_t idx = this->data_partition_->leaf_begin(i);
for (int j = 0; j < data_partition_->leaf_count(i); ++j) { for (int j = 0; j < this->data_partition_->leaf_count(i); ++j) {
leaf_map_[ind[idx + j]] = i; leaf_map_[ind[idx + j]] = i;
} }
} }
} }
template<bool HAS_NAN> template<typename TREE_LEARNER_TYPE>
void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const { template <bool HAS_NAN>
void LinearTreeLearner<TREE_LEARNER_TYPE>::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
tree->SetIsLinear(true); tree->SetIsLinear(true);
int num_leaves = tree->num_leaves(); int num_leaves = tree->num_leaves();
int num_threads = OMP_NUM_THREADS(); int num_threads = OMP_NUM_THREADS();
...@@ -209,11 +216,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t ...@@ -209,11 +216,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
std::vector<int> numerical_features; std::vector<int> numerical_features;
std::vector<const float*> data_ptr; std::vector<const float*> data_ptr;
for (size_t j = 0; j < raw_features.size(); ++j) { for (size_t j = 0; j < raw_features.size(); ++j) {
int feat = train_data_->InnerFeatureIndex(raw_features[j]); int feat = this->train_data_->InnerFeatureIndex(raw_features[j]);
auto bin_mapper = train_data_->FeatureBinMapper(feat); auto bin_mapper = this->train_data_->FeatureBinMapper(feat);
if (bin_mapper->bin_type() == BinType::NumericalBin) { if (bin_mapper->bin_type() == BinType::NumericalBin) {
numerical_features.push_back(feat); numerical_features.push_back(feat);
data_ptr.push_back(train_data_->raw_index(feat)); data_ptr.push_back(this->train_data_->raw_index(feat));
} }
} }
leaf_features.push_back(numerical_features); leaf_features.push_back(numerical_features);
...@@ -245,12 +252,12 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t ...@@ -245,12 +252,12 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
} }
} }
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel num_threads(OMP_NUM_THREADS()) if (num_data_ > 1024) #pragma omp parallel num_threads(OMP_NUM_THREADS()) if (this->num_data_ > 1024)
{ {
std::vector<float> curr_row(max_num_features + 1); std::vector<float> curr_row(max_num_features + 1);
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
#pragma omp for schedule(static) #pragma omp for schedule(static)
for (int i = 0; i < num_data_; ++i) { for (int i = 0; i < this->num_data_; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
int leaf_num = leaf_map_[i]; int leaf_num = leaf_map_[i];
if (leaf_num < 0) { if (leaf_num < 0) {
...@@ -312,11 +319,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t ...@@ -312,11 +319,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
} }
if (!HAS_NAN) { if (!HAS_NAN) {
for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
total_nonzero[leaf_num] = data_partition_->leaf_count(leaf_num); total_nonzero[leaf_num] = this->data_partition_->leaf_count(leaf_num);
} }
} }
double shrinkage = tree->shrinkage(); double shrinkage = tree->shrinkage();
double decay_rate = config_->refit_decay_rate; double decay_rate = this->config_->refit_decay_rate;
// copy into eigen matrices and solve // copy into eigen matrices and solve
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
...@@ -340,7 +347,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t ...@@ -340,7 +347,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j]; XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j];
XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2); XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2);
if ((feat1 == feat2) && (feat1 < num_feat)) { if ((feat1 == feat2) && (feat1 < num_feat)) {
XTHX_mat(feat1, feat2) += config_->linear_lambda; XTHX_mat(feat1, feat2) += this->config_->linear_lambda;
} }
++j; ++j;
} }
...@@ -366,7 +373,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t ...@@ -366,7 +373,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
tree->SetLeafFeaturesInner(leaf_num, features_new); tree->SetLeafFeaturesInner(leaf_num, features_new);
std::vector<int> features_raw(features_new.size()); std::vector<int> features_raw(features_new.size());
for (size_t i = 0; i < features_new.size(); ++i) { for (size_t i = 0; i < features_new.size(); ++i) {
features_raw[i] = train_data_->RealFeatureIndex(features_new[i]); features_raw[i] = this->train_data_->RealFeatureIndex(features_new[i]);
} }
tree->SetLeafFeatures(leaf_num, features_raw); tree->SetLeafFeatures(leaf_num, features_raw);
tree->SetLeafCoeffs(leaf_num, coeffs_vec); tree->SetLeafCoeffs(leaf_num, coeffs_vec);
...@@ -378,4 +385,19 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t ...@@ -378,4 +385,19 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
} }
} }
} }
template void LinearTreeLearner<SerialTreeLearner>::Init(const Dataset* train_data, bool is_constant_hessian);
template void LinearTreeLearner<SerialTreeLearner>::InitLinear(const Dataset* train_data, const int max_leaves);
template Tree* LinearTreeLearner<SerialTreeLearner>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree);
template Tree* LinearTreeLearner<SerialTreeLearner>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const;
template Tree* LinearTreeLearner<SerialTreeLearner>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
const score_t* gradients, const score_t *hessians) const;
template void LinearTreeLearner<GPUTreeLearner>::Init(const Dataset* train_data, bool is_constant_hessian);
template void LinearTreeLearner<GPUTreeLearner>::InitLinear(const Dataset* train_data, const int max_leaves);
template Tree* LinearTreeLearner<GPUTreeLearner>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree);
template Tree* LinearTreeLearner<GPUTreeLearner>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const;
template Tree* LinearTreeLearner<GPUTreeLearner>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
const score_t* gradients, const score_t *hessians) const;
} // namespace LightGBM } // namespace LightGBM
...@@ -11,13 +11,15 @@ ...@@ -11,13 +11,15 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "gpu_tree_learner.h"
#include "serial_tree_learner.h" #include "serial_tree_learner.h"
namespace LightGBM { namespace LightGBM {
class LinearTreeLearner: public SerialTreeLearner { template <typename TREE_LEARNER_TYPE>
class LinearTreeLearner: public TREE_LEARNER_TYPE {
public: public:
explicit LinearTreeLearner(const Config* config) : SerialTreeLearner(config) {} explicit LinearTreeLearner(const Config* config) : TREE_LEARNER_TYPE(config) {}
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
...@@ -38,12 +40,12 @@ class LinearTreeLearner: public SerialTreeLearner { ...@@ -38,12 +40,12 @@ class LinearTreeLearner: public SerialTreeLearner {
void AddPredictionToScore(const Tree* tree, void AddPredictionToScore(const Tree* tree,
double* out_score) const override { double* out_score) const override {
CHECK_LE(tree->num_leaves(), data_partition_->num_leaves()); CHECK_LE(tree->num_leaves(), this->data_partition_->num_leaves());
bool has_nan = false; bool has_nan = false;
if (any_nan_) { if (any_nan_) {
for (int i = 0; i < tree->num_leaves() - 1 ; ++i) { for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
// use split_feature because split_feature_inner doesn't work when refitting existing tree // use split_feature because split_feature_inner doesn't work when refitting existing tree
if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) { if (contains_nan_[this->train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
has_nan = true; has_nan = true;
break; break;
} }
...@@ -69,13 +71,13 @@ class LinearTreeLearner: public SerialTreeLearner { ...@@ -69,13 +71,13 @@ class LinearTreeLearner: public SerialTreeLearner {
leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num); leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num);
leaf_output[leaf_num] = tree->LeafOutput(leaf_num); leaf_output[leaf_num] = tree->LeafOutput(leaf_num);
for (int feat : tree->LeafFeaturesInner(leaf_num)) { for (int feat : tree->LeafFeaturesInner(leaf_num)) {
feat_ptr[leaf_num].push_back(train_data_->raw_index(feat)); feat_ptr[leaf_num].push_back(this->train_data_->raw_index(feat));
} }
leaf_num_features[leaf_num] = static_cast<int>(feat_ptr[leaf_num].size()); leaf_num_features[leaf_num] = static_cast<int>(feat_ptr[leaf_num].size());
} }
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (num_data_ > 1024) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (this->num_data_ > 1024)
for (int i = 0; i < num_data_; ++i) { for (int i = 0; i < this->num_data_; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
int leaf_num = leaf_map_[i]; int leaf_num = leaf_map_[i];
if (leaf_num < 0) { if (leaf_num < 0) {
......
...@@ -17,7 +17,7 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con ...@@ -17,7 +17,7 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
if (device_type == std::string("cpu")) { if (device_type == std::string("cpu")) {
if (learner_type == std::string("serial")) { if (learner_type == std::string("serial")) {
if (config->linear_tree) { if (config->linear_tree) {
return new LinearTreeLearner(config); return new LinearTreeLearner<SerialTreeLearner>(config);
} else { } else {
return new SerialTreeLearner(config); return new SerialTreeLearner(config);
} }
...@@ -30,7 +30,11 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con ...@@ -30,7 +30,11 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
} }
} else if (device_type == std::string("gpu")) { } else if (device_type == std::string("gpu")) {
if (learner_type == std::string("serial")) { if (learner_type == std::string("serial")) {
if (config->linear_tree) {
return new LinearTreeLearner<GPUTreeLearner>(config);
} else {
return new GPUTreeLearner(config); return new GPUTreeLearner(config);
}
} else if (learner_type == std::string("feature")) { } else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<GPUTreeLearner>(config); return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
} else if (learner_type == std::string("data")) { } else if (learner_type == std::string("data")) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment