Commit 574d5342 authored by Guolin Ke's avatar Guolin Ke
Browse files

better performance for reset parameters

parent 25200c3a
...@@ -100,8 +100,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -100,8 +100,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
num_data_ = train_data->num_data(); num_data_ = train_data->num_data();
// create buffer for gradients and hessians // create buffer for gradients and hessians
if (object_function_ != nullptr) { if (object_function_ != nullptr) {
gradients_ = std::vector<score_t>(num_data_ * num_class_); gradients_.resize(num_data_ * num_class_);
hessians_ = std::vector<score_t>(num_data_ * num_class_); hessians_.resize(num_data_ * num_class_);
} }
// get max feature index // get max feature index
max_feature_idx_ = train_data->num_total_features() - 1; max_feature_idx_ = train_data->num_total_features() - 1;
...@@ -114,8 +114,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -114,8 +114,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
|| (gbdt_config_->bagging_fraction != new_config->bagging_fraction)) { || (gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer // if need bagging, create buffer
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) { if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_); out_of_bag_data_indices_.resize(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_); bag_data_indices_.resize(num_data_);
} else { } else {
out_of_bag_data_cnt_ = 0; out_of_bag_data_cnt_ = 0;
out_of_bag_data_indices_.clear(); out_of_bag_data_indices_.clear();
......
...@@ -16,8 +16,8 @@ namespace LightGBM { ...@@ -16,8 +16,8 @@ namespace LightGBM {
*/ */
class DataPartition { class DataPartition {
public: public:
DataPartition(data_size_t num_data, int num_leafs) DataPartition(data_size_t num_data, int num_leaves)
:num_data_(num_data), num_leaves_(num_leafs) { :num_data_(num_data), num_leaves_(num_leaves) {
leaf_begin_.resize(num_leaves_); leaf_begin_.resize(num_leaves_);
leaf_count_.resize(num_leaves_); leaf_count_.resize(num_leaves_);
indices_.resize(num_data_); indices_.resize(num_data_);
...@@ -35,6 +35,13 @@ public: ...@@ -35,6 +35,13 @@ public:
left_write_pos_buf_.resize(num_threads_); left_write_pos_buf_.resize(num_threads_);
right_write_pos_buf_.resize(num_threads_); right_write_pos_buf_.resize(num_threads_);
} }
void ResetLeaves(int num_leaves) {
num_leaves_ = num_leaves;
leaf_begin_.resize(num_leaves_);
leaf_count_.resize(num_leaves_);
}
~DataPartition() { ~DataPartition() {
} }
......
...@@ -364,9 +364,9 @@ public: ...@@ -364,9 +364,9 @@ public:
} }
is_enough_ = (cache_size_ == total_size_); is_enough_ = (cache_size_ == total_size_);
if (!is_enough_) { if (!is_enough_) {
mapper_ = std::vector<int>(total_size_); mapper_.resize(total_size_);
inverse_mapper_ = std::vector<int>(cache_size_); inverse_mapper_.resize(cache_size_);
last_used_time_ = std::vector<int>(cache_size_); last_used_time_.resize(cache_size_);
ResetMap(); ResetMap();
} }
} }
......
...@@ -105,7 +105,7 @@ void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) { ...@@ -105,7 +105,7 @@ void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
// push split information for all leaves // push split information for all leaves
best_split_per_leaf_.resize(tree_config_->num_leaves); best_split_per_leaf_.resize(tree_config_->num_leaves);
data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves)); data_partition_->ResetLeaves(tree_config_->num_leaves);
} else { } else {
tree_config_ = tree_config; tree_config_ = tree_config;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment