Commit 61527856 authored by remcob-gr's avatar remcob-gr Committed by Nikita Titov
Browse files

When loading a binary file, take feature penalty and monotone constraints from...

When loading a binary file, take feature penalty and monotone constraints from config if given there. (#1881)

* When loading a binary file, take feature penalty from config if given there.

* When loading a binary file, take feature penalty from config if given there.

* Fix crash when num_features != num_total_features and feature_contri is given.

* Apply the same logic to monotone_types_.

* Fix indentation
parent d038aa57
......@@ -370,10 +370,22 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
}
mem_ptr += sizeof(int) * (dataset->num_groups_);
const int8_t* tmp_ptr_monotone_type = reinterpret_cast<const int8_t*>(mem_ptr);
dataset->monotone_types_.clear();
for (int i = 0; i < dataset->num_features_; ++i) {
dataset->monotone_types_.push_back(tmp_ptr_monotone_type[i]);
if(!config_.monotone_constraints.empty()) {
CHECK(dataset->num_total_features_ == config_.monotone_constraints.size());
dataset->monotone_types_.resize(dataset->num_features_);
for(int i = 0; i < dataset->num_total_features_; ++i){
int inner_fidx = dataset->InnerFeatureIndex(i);
if(inner_fidx >= 0) {
dataset->monotone_types_[inner_fidx] = config_.monotone_constraints[i];
}
}
}
else {
const int8_t* tmp_ptr_monotone_type = reinterpret_cast<const int8_t*>(mem_ptr);
dataset->monotone_types_.clear();
for (int i = 0; i < dataset->num_features_; ++i) {
dataset->monotone_types_.push_back(tmp_ptr_monotone_type[i]);
}
}
mem_ptr += sizeof(int8_t) * (dataset->num_features_);
......@@ -381,10 +393,22 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
dataset->monotone_types_.clear();
}
const double* tmp_ptr_feature_penalty = reinterpret_cast<const double*>(mem_ptr);
dataset->feature_penalty_.clear();
for (int i = 0; i < dataset->num_features_; ++i) {
dataset->feature_penalty_.push_back(tmp_ptr_feature_penalty[i]);
if(!config_.feature_contri.empty()) {
CHECK(dataset->num_total_features_ == config_.feature_contri.size());
dataset->feature_penalty_.resize(dataset->num_features_);
for(int i = 0; i < dataset->num_total_features_; ++i){
int inner_fidx = dataset->InnerFeatureIndex(i);
if(inner_fidx >= 0) {
dataset->feature_penalty_[inner_fidx] = config_.feature_contri[i];
}
}
}
else {
const double* tmp_ptr_feature_penalty = reinterpret_cast<const double*>(mem_ptr);
dataset->feature_penalty_.clear();
for (int i = 0; i < dataset->num_features_; ++i) {
dataset->feature_penalty_.push_back(tmp_ptr_feature_penalty[i]);
}
}
mem_ptr += sizeof(double) * (dataset->num_features_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment