Commit c8fbd42b authored by Guolin Ke

use subset to speed up bagging

parent 873528c1
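
Summary of the change: when only a small fraction of rows is in the bag, it is cheaper to materialize the sampled rows as a dense temporary Dataset and train on that than to keep masking the full dataset through bag_data_indices_. The fast path is enabled in GBDT::ResetTrainingData() when the amortized bag rate (bagging_fraction / bagging_freq) falls below 0.3. GBDT::Bagging() then builds the subset with the extended Dataset::Subset() (skipping metadata the tree learner never reads), TrainOneIter() gathers the matching gradients and hessians, and the new TreeLearner::ResetTrainingData() points the learner at the subset. Training scores are updated by re-predicting with the finished tree, since the learner's data partition now refers to subset rows rather than the full data.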
@@ -328,7 +328,7 @@ public:
     return used_feature_map_[col_idx];
   }
-  Dataset* Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse) const;
+  Dataset* Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse, bool need_meta_data) const;
   LIGHTGBM_EXPORT void FinishLoad();
...
@@ -27,6 +27,8 @@ public:
   */
   virtual void Init(const Dataset* train_data) = 0;
+
+  virtual void ResetTrainingData(const Dataset* train_data) = 0;
   /*!
   * \brief Reset tree configs
   * \param tree_config config of tree
...
@@ -134,10 +134,17 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
       right_cnts_buf_.resize(num_threads_);
       left_write_pos_buf_.resize(num_threads_);
       right_write_pos_buf_.resize(num_threads_);
+      double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
+      is_use_subset_ = false;
+      if (average_bag_rate < 0.3) {
+        is_use_subset_ = true;
+        Log::Debug("use subset for bagging");
+      }
     } else {
       bag_data_cnt_ = num_data_;
       bag_data_indices_.clear();
       tmp_indices_.clear();
+      is_use_subset_ = false;
     }
   }
   train_data_ = train_data;
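
For intuition about the 0.3 cutoff: average_bag_rate amortizes the sampled fraction over the bagging period, since Bagging() only re-samples every bagging_freq iterations and the subset is reused in between. With bagging_fraction = 0.8 and bagging_freq = 5 the rate is 0.16, so the subset path is taken; with bagging_freq = 1 it stays at 0.8 and the existing index-based bagging is kept. The constant 0.3 is this commit's heuristic for when copying rows beats masking them.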
@@ -196,6 +203,7 @@ data_size_t GBDT::BaggingHelper(data_size_t start, data_size_t cnt, data_size_t*
       buffer[bag_data_cnt + cur_right_cnt++] = start + i;
     }
   }
+  CHECK(buffer[bag_data_cnt - 1] > buffer[bag_data_cnt]);
   CHECK(cur_left_cnt == bag_data_cnt);
   return cur_left_cnt;
 }
@@ -240,15 +248,24 @@ void GBDT::Bagging(int iter) {
                   tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
       }
     }
+    bag_data_cnt_ = left_cnt;
+    CHECK(bag_data_indices_[bag_data_cnt_ - 1] > bag_data_indices_[bag_data_cnt_]);
     Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
     // set bagging data to tree learner
-    tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
+    if (!is_use_subset_) {
+      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
+    } else {
+      // get subset
+      tmp_subset_.reset(train_data_->Subset(bag_data_indices_.data(), bag_data_cnt_, false, false));
+      tmp_subset_->FinishLoad();
+      tree_learner_->ResetTrainingData(tmp_subset_.get());
+    }
   }
 }
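
The subset is a dense copy of just the in-bag rows, so later histogram construction scans contiguous memory instead of jumping through bag_data_indices_, and the copy is reused until the next re-bagging. A minimal standalone sketch of the gather pattern, with hypothetical names (GatherSubset and Row are illustrative, not LightGBM API):

#include <cstdint>
#include <vector>

// Dense gather: new row i holds old row used_indices[i], mirroring what
// Dataset::Subset() does for every feature's bin column.
template <typename Row>
std::vector<Row> GatherSubset(const std::vector<Row>& rows,
                              const std::vector<int32_t>& used_indices) {
  std::vector<Row> subset;
  subset.reserve(used_indices.size());
  for (int32_t idx : used_indices) {
    subset.push_back(rows[idx]);  // sequential writes, random reads
  }
  return subset;
}

The gather costs one pass over the bag; every tree trained before the next re-bagging then reads the compact copy.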
 void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
   // we need to predict out-of-bag scores of data for boosting
-  if (num_data_ - bag_data_cnt_ > 0) {
+  if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
     train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class);
   }
 }
@@ -262,8 +279,24 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
   }
   // bagging logic
   Bagging(iter_);
+  if (is_use_subset_ && bag_data_cnt_ < num_data_) {
+    if (gradients_.empty()) {
+      size_t total_size = static_cast<size_t>(num_data_) * num_class_;
+      gradients_.resize(total_size);
+      hessians_.resize(total_size);
+    }
+    // get sub gradients
+    for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
+      auto bias = curr_class * num_data_;
+      for (int i = 0; i < bag_data_cnt_; ++i) {
+        gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
+        hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
+      }
+    }
+    gradient = gradients_.data();
+    hessian = hessians_.data();
+  }
   for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
     // train a new tree
     std::unique_ptr<Tree> new_tree(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
     // if cannot learn a new tree, then stop
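
Gradients and hessians are laid out class-major, at slot curr_class * num_data_ + i. The gathered buffers keep the full-data stride num_data_ rather than bag_data_cnt_, wasting the tail of each class block but leaving the pointer arithmetic in the Train() call above (gradient + curr_class * num_data_) unchanged. For example, with num_data_ = 1000, num_class_ = 3 and bag_data_cnt_ = 250, the class-1 learner reads slots 1000..1249, which after the gather hold the sampled rows' gradients in subset order.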
@@ -328,7 +361,11 @@ bool GBDT::EvalAndCheckEarlyStopping() {
 void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
   // update training score
-  train_score_updater_->AddScore(tree_learner_.get(), curr_class);
+  if (!is_use_subset_) {
+    train_score_updater_->AddScore(tree_learner_.get(), curr_class);
+  } else {
+    train_score_updater_->AddScore(tree, curr_class);
+  }
   // update validation score
   for (auto& score_updater : valid_score_updater_) {
     score_updater->AddScore(tree, curr_class);
...
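
Why the score update changes: the AddScore(tree_learner_, ...) overload reads the learner's leaf partition, but under the subset path that partition indexes rows of tmp_subset_, not of the full training set, so the scores of all num_data_ rows are added by re-predicting with the finished tree instead. Because that AddScore(tree, ...) pass already covers out-of-bag rows, UpdateScoreOutOfBag() above now skips its extra pass when is_use_subset_ is set; otherwise out-of-bag rows would be scored twice.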
@@ -339,6 +339,8 @@ protected:
   std::vector<data_size_t> left_write_pos_buf_;
   /*! \brief Buffer for multi-threading bagging */
   std::vector<data_size_t> right_write_pos_buf_;
+  std::unique_ptr<Dataset> tmp_subset_;
+  bool is_use_subset_;
 };
 }  // namespace LightGBM
...
@@ -488,7 +488,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
   auto ret = std::unique_ptr<Dataset>(
     full_dataset->Subset(used_row_indices,
                          num_used_row_indices,
-                         io_config.is_enable_sparse));
+                         io_config.is_enable_sparse,
+                         true));
   ret->FinishLoad();
   *out = ret.release();
   API_END();
...
@@ -22,6 +22,7 @@ Dataset::Dataset() {
 }
 Dataset::Dataset(data_size_t num_data) {
+  data_filename_ = "noname";
   num_data_ = num_data;
   metadata_.Init(num_data_, -1, -1);
 }
@@ -56,7 +57,8 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_spars
   label_idx_ = dataset->label_idx_;
 }
-Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse) const {
+Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_indices,
+                         bool is_enable_sparse, bool need_meta_data) const {
   auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_indices));
   ret->CopyFeatureMapperFrom(this, is_enable_sparse);
 #pragma omp parallel for schedule(guided)
@@ -66,7 +68,9 @@ Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_i
       ret->features_[fidx]->PushBin(0, i, iterator->Get(used_indices[i]));
     }
   }
-  ret->metadata_.Init(metadata_, used_indices, num_used_indices);
+  if (need_meta_data) {
+    ret->metadata_.Init(metadata_, used_indices, num_used_indices);
+  }
   return ret.release();
 }
...
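
need_meta_data separates the two callers of Subset(): internal bagging never reads labels or weights (the learner consumes gradients and hessians computed on the full data), while LGBM_DatasetGetSubset() returns a user-visible Dataset that must keep its metadata. Recapping the two call sites from this commit, where the last two arguments are is_enable_sparse and need_meta_data:

// In GBDT::Bagging(): features only, metadata skipped.
tmp_subset_.reset(train_data_->Subset(bag_data_indices_.data(), bag_data_cnt_,
                                      false, false));

// In LGBM_DatasetGetSubset(): user-visible dataset, metadata copied.
auto ret = std::unique_ptr<Dataset>(full_dataset->Subset(
    used_row_indices, num_used_row_indices, io_config.is_enable_sparse, true));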
@@ -96,7 +96,7 @@ void DataParallelTreeLearner::BeforeTrain() {
   // sync global data sumup info
   std::tuple<data_size_t, double, double> data(smaller_leaf_splits_->num_data_in_leaf(),
     smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians());
   int size = sizeof(data);
   std::memcpy(input_buffer_.data(), &data, size);
   // global sumup reduce
@@ -126,65 +126,67 @@ void DataParallelTreeLearner::BeforeTrain() {
 void DataParallelTreeLearner::FindBestThresholds() {
   // construct local histograms
 #pragma omp parallel for schedule(guided)
   for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
     if ((!is_feature_used_.empty() && is_feature_used_[feature_index] == false)) continue;
     // construct histograms for smaller leaf
     if (ordered_bins_[feature_index] == nullptr) {
-      smaller_leaf_histogram_array_[feature_index].Construct(smaller_leaf_splits_->data_indices(),
-                                                             smaller_leaf_splits_->num_data_in_leaf(),
-                                                             smaller_leaf_splits_->sum_gradients(),
-                                                             smaller_leaf_splits_->sum_hessians(),
-                                                             ptr_to_ordered_gradients_smaller_leaf_,
-                                                             ptr_to_ordered_hessians_smaller_leaf_);
+      smaller_leaf_histogram_array_[feature_index].Construct(
+        train_data_->FeatureAt(feature_index)->bin_data(),
+        smaller_leaf_splits_->data_indices(),
+        smaller_leaf_splits_->num_data_in_leaf(),
+        smaller_leaf_splits_->sum_gradients(),
+        smaller_leaf_splits_->sum_hessians(),
+        ptr_to_ordered_gradients_smaller_leaf_,
+        ptr_to_ordered_hessians_smaller_leaf_);
     } else {
       smaller_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
                                                              smaller_leaf_splits_->LeafIndex(),
                                                              smaller_leaf_splits_->num_data_in_leaf(),
                                                              smaller_leaf_splits_->sum_gradients(),
                                                              smaller_leaf_splits_->sum_hessians(),
                                                              gradients_,
                                                              hessians_);
     }
     // copy to buffer
     std::memcpy(input_buffer_.data() + buffer_write_start_pos_[feature_index],
                 smaller_leaf_histogram_array_[feature_index].HistogramData(),
                 smaller_leaf_histogram_array_[feature_index].SizeOfHistgram());
   }
   // Reduce scatter for histogram
   Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(),
                          block_len_.data(), output_buffer_.data(), &HistogramBinEntry::SumReducer);
 #pragma omp parallel for schedule(guided)
   for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
     if (!is_feature_aggregated_[feature_index]) continue;
     // copy global sumup info
     smaller_leaf_histogram_array_[feature_index].SetSumup(
       GetGlobalDataCountInLeaf(smaller_leaf_splits_->LeafIndex()),
       smaller_leaf_splits_->sum_gradients(),
       smaller_leaf_splits_->sum_hessians());
     // restore global histograms from buffer
     smaller_leaf_histogram_array_[feature_index].FromMemory(
       output_buffer_.data() + buffer_read_start_pos_[feature_index]);
     // find best threshold for smaller child
     smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
       &smaller_leaf_splits_->BestSplitPerFeature()[feature_index]);
     // only root leaf
     if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) continue;
     // construct histograms for larger leaf; we init the larger leaf as the parent, so we can just subtract the smaller leaf's histograms
     larger_leaf_histogram_array_[feature_index].Subtract(
       smaller_leaf_histogram_array_[feature_index]);
     // set sumup info for histogram
     larger_leaf_histogram_array_[feature_index].SetSumup(
       GetGlobalDataCountInLeaf(larger_leaf_splits_->LeafIndex()),
       larger_leaf_splits_->sum_gradients(), larger_leaf_splits_->sum_hessians());
     // find best threshold for larger child
     larger_leaf_histogram_array_[feature_index].FindBestThreshold(
       &larger_leaf_splits_->BestSplitPerFeature()[feature_index]);
   }
 }
@@ -214,7 +216,7 @@ void DataParallelTreeLearner::FindBestSplitsForLeaves() {
   std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best, sizeof(SplitInfo));
   Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo),
                      output_buffer_.data(), &SplitInfo::MaxReducer);
   std::memcpy(&smaller_best, output_buffer_.data(), sizeof(SplitInfo));
   std::memcpy(&larger_best, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
...
@@ -41,7 +41,12 @@ public:
     leaf_begin_.resize(num_leaves_);
     leaf_count_.resize(num_leaves_);
   }
+  void ResetNumData(int num_data) {
+    num_data_ = num_data;
+    indices_.resize(num_data_);
+    temp_left_indices_.resize(num_data_);
+    temp_right_indices_.resize(num_data_);
+  }
   ~DataPartition() {
   }
...
@@ -31,7 +31,6 @@ public:
   void Init(const Feature* feature, int feature_idx, const TreeConfig* tree_config) {
     feature_idx_ = feature_idx;
     tree_config_ = tree_config;
-    bin_data_ = feature->bin_data();
     num_bins_ = feature->num_bin();
     data_.resize(num_bins_);
     if (feature->bin_type() == BinType::NumericalBin) {
@@ -51,13 +50,13 @@ public:
   * \param ordered_hessians Ordered hessians
   * \param data_indices data indices of current leaf
   */
-  void Construct(const data_size_t* data_indices, data_size_t num_data, double sum_gradients,
+  void Construct(const Bin* bin_data, const data_size_t* data_indices, data_size_t num_data, double sum_gradients,
                  double sum_hessians, const score_t* ordered_gradients, const score_t* ordered_hessians) {
     std::memset(data_.data(), 0, sizeof(HistogramBinEntry) * num_bins_);
     num_data_ = num_data;
     sum_gradients_ = sum_gradients;
     sum_hessians_ = sum_hessians + 2 * kEpsilon;
-    bin_data_->ConstructHistogram(data_indices, num_data, ordered_gradients, ordered_hessians, data_.data());
+    bin_data->ConstructHistogram(data_indices, num_data, ordered_gradients, ordered_hessians, data_.data());
   }
   /*!
@@ -315,8 +314,6 @@ private:
   int feature_idx_;
   /*! \brief pointer of tree config */
   const TreeConfig* tree_config_;
-  /*! \brief the bin data of current feature */
-  const Bin* bin_data_;
   /*! \brief number of bin of histogram */
   unsigned int num_bins_;
   /*! \brief sum of gradient of each bin */
...
@@ -22,6 +22,10 @@ public:
       best_split_per_feature_[i].feature = i;
     }
   }
+  void ResetNumData(data_size_t num_data) {
+    num_data_ = num_data;
+    num_data_in_leaf_ = num_data;
+  }
   ~LeafSplits() {
   }
...
@@ -83,6 +83,45 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
   Log::Info("Number of data: %d, number of features: %d", num_data_, num_features_);
 }
+
+void SerialTreeLearner::ResetTrainingData(const Dataset* train_data) {
+  train_data_ = train_data;
+  num_data_ = train_data_->num_data();
+  num_features_ = train_data_->num_features();
+  // initialize ordered_bins_ with nullptr
+  ordered_bins_.resize(num_features_);
+  // get ordered bin
+#pragma omp parallel for schedule(guided)
+  for (int i = 0; i < num_features_; ++i) {
+    ordered_bins_[i].reset(train_data_->FeatureAt(i)->bin_data()->CreateOrderedBin());
+  }
+  has_ordered_bin_ = false;
+  // check existing for ordered bin
+  for (int i = 0; i < num_features_; ++i) {
+    if (ordered_bins_[i] != nullptr) {
+      has_ordered_bin_ = true;
+      break;
+    }
+  }
+  // initialize splits for leaf
+  smaller_leaf_splits_->ResetNumData(num_data_);
+  larger_leaf_splits_->ResetNumData(num_data_);
+  // initialize data partition
+  data_partition_->ResetNumData(num_data_);
+  is_feature_used_.resize(num_features_);
+  // initialize ordered gradients and hessians
+  ordered_gradients_.resize(num_data_);
+  ordered_hessians_.resize(num_data_);
+  // if has ordered bin, need to allocate a buffer to fast split
+  if (has_ordered_bin_) {
+    is_data_in_leaf_.resize(num_data_);
+  }
+}
+
 void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
   if (tree_config_->num_leaves != tree_config->num_leaves) {
@@ -351,7 +390,9 @@ void SerialTreeLearner::FindBestThresholds() {
       // construct histograms for smaller leaf
       if (ordered_bins_[feature_index] == nullptr) {
         // if not use ordered bin
-        smaller_leaf_histogram_array_[feature_index].Construct(smaller_leaf_splits_->data_indices(),
+        smaller_leaf_histogram_array_[feature_index].Construct(
+          train_data_->FeatureAt(feature_index)->bin_data(),
+          smaller_leaf_splits_->data_indices(),
           smaller_leaf_splits_->num_data_in_leaf(),
           smaller_leaf_splits_->sum_gradients(),
           smaller_leaf_splits_->sum_hessians(),
@@ -380,7 +421,9 @@ void SerialTreeLearner::FindBestThresholds() {
       } else {
         if (ordered_bins_[feature_index] == nullptr) {
           // if not use ordered bin
-          larger_leaf_histogram_array_[feature_index].Construct(larger_leaf_splits_->data_indices(),
+          larger_leaf_histogram_array_[feature_index].Construct(
+            train_data_->FeatureAt(feature_index)->bin_data(),
+            larger_leaf_splits_->data_indices(),
             larger_leaf_splits_->num_data_in_leaf(),
             larger_leaf_splits_->sum_gradients(),
             larger_leaf_splits_->sum_hessians(),
...
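
SerialTreeLearner::ResetTrainingData() mirrors Init() but goes through resize() on long-lived members instead of reallocating them, so calling it on every re-bagging is cheap: std::vector keeps its capacity when shrunk, and growing back to the full size reuses the same storage. A small standalone sketch of that property (illustrative, not LightGBM code):

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> buf(1000);          // full-data size
  buf.resize(250);                       // shrink to subset size: no reallocation
  std::printf("capacity: %zu\n", buf.capacity());  // still at least 1000
  buf.resize(1000);                      // grow back: reuses existing storage
  return 0;
}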
@@ -32,6 +32,8 @@ public:
   void Init(const Dataset* train_data) override;
+
+  void ResetTrainingData(const Dataset* train_data) override;
   void ResetConfig(const TreeConfig* tree_config) override;
   Tree* Train(const score_t* gradients, const score_t *hessians) override;
...