"vscode:/vscode.git/clone" did not exist on "8ebef94cfe0627aef6def578fb0e6a2e082dbf1e"
Commit 574d5342 authored by Guolin Ke's avatar Guolin Ke
Browse files

better performance for reset parameters

parent 25200c3a
......@@ -100,8 +100,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
num_data_ = train_data->num_data();
// create buffer for gradients and hessians
if (object_function_ != nullptr) {
gradients_ = std::vector<score_t>(num_data_ * num_class_);
hessians_ = std::vector<score_t>(num_data_ * num_class_);
gradients_.resize(num_data_ * num_class_);
hessians_.resize(num_data_ * num_class_);
}
// get max feature index
max_feature_idx_ = train_data->num_total_features() - 1;
......@@ -114,8 +114,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
|| (gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_);
out_of_bag_data_indices_.resize(num_data_);
bag_data_indices_.resize(num_data_);
} else {
out_of_bag_data_cnt_ = 0;
out_of_bag_data_indices_.clear();
......
......@@ -16,8 +16,8 @@ namespace LightGBM {
*/
class DataPartition {
public:
DataPartition(data_size_t num_data, int num_leafs)
:num_data_(num_data), num_leaves_(num_leafs) {
DataPartition(data_size_t num_data, int num_leaves)
:num_data_(num_data), num_leaves_(num_leaves) {
leaf_begin_.resize(num_leaves_);
leaf_count_.resize(num_leaves_);
indices_.resize(num_data_);
......@@ -35,6 +35,13 @@ public:
left_write_pos_buf_.resize(num_threads_);
right_write_pos_buf_.resize(num_threads_);
}
void ResetLeaves(int num_leaves) {
num_leaves_ = num_leaves;
leaf_begin_.resize(num_leaves_);
leaf_count_.resize(num_leaves_);
}
~DataPartition() {
}
......
......@@ -364,9 +364,9 @@ public:
}
is_enough_ = (cache_size_ == total_size_);
if (!is_enough_) {
mapper_ = std::vector<int>(total_size_);
inverse_mapper_ = std::vector<int>(cache_size_);
last_used_time_ = std::vector<int>(cache_size_);
mapper_.resize(total_size_);
inverse_mapper_.resize(cache_size_);
last_used_time_.resize(cache_size_);
ResetMap();
}
}
......
......@@ -105,7 +105,7 @@ void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
// push split information for all leaves
best_split_per_leaf_.resize(tree_config_->num_leaves);
data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
data_partition_->ResetLeaves(tree_config_->num_leaves);
} else {
tree_config_ = tree_config;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment