Unverified Commit b37065db authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

support to override some parameters in Dataset (#1876)

* add warnings for override parameters of Dataset

* fix pep8

* add feature_penalty

* refactor

* add R's code

* Update basic.py

* Update basic.py

* fix parameter bug

* Update lgb.Dataset.R

* fix a bug
parent f3080967
...@@ -492,6 +492,10 @@ Dataset <- R6::R6Class( ...@@ -492,6 +492,10 @@ Dataset <- R6::R6Class(
update_params = function(params) { update_params = function(params) {
# Parameter updating # Parameter updating
if (!lgb.is.null.handle(private$handle)) {
lgb.call("LGBM_DatasetUpdateParam_R", ret = NULL, private$handle, lgb.params2str(params))
return(invisible(self))
}
private$params <- modifyList(private$params, params) private$params <- modifyList(private$params, params)
return(invisible(self)) return(invisible(self))
......
...@@ -309,6 +309,14 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle, ...@@ -309,6 +309,14 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
const void** out_ptr, const void** out_ptr,
int* out_type); int* out_type);
/*!
 * \brief Update parameters of a constructed Dataset.
 *        Parameters fixed at construction time (e.g. max_bin) cannot be
 *        changed here; see Dataset::ResetConfig, which only warns about them.
 * \param handle an instance of data matrix
 * \param parameters new parameters to apply, overriding the current ones
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters);
/*! /*!
* \brief get number of data. * \brief get number of data.
* \param handle the handle to the dataset * \param handle the handle to the dataset
......
...@@ -575,6 +575,8 @@ public: ...@@ -575,6 +575,8 @@ public:
return bufs; return bufs;
} }
void ResetConfig(const char* parameters);
/*! \brief Get Number of data */ /*! \brief Get Number of data */
inline data_size_t num_data() const { return num_data_; } inline data_size_t num_data() const { return num_data_; }
...@@ -615,6 +617,11 @@ private: ...@@ -615,6 +617,11 @@ private:
std::vector<int8_t> monotone_types_; std::vector<int8_t> monotone_types_;
std::vector<double> feature_penalty_; std::vector<double> feature_penalty_;
bool is_finish_load_; bool is_finish_load_;
// Construction-time parameters, cached so ResetConfig() can detect (and warn
// about) attempts to change them after the Dataset has been built; they are
// also written to / read back from the binary Dataset file format.
int max_bin_;
int bin_construct_sample_cnt_;
int min_data_in_bin_;
bool use_missing_;
bool zero_as_missing_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -171,6 +171,16 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_DatasetGetField_R(LGBM_SE handle, ...@@ -171,6 +171,16 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_DatasetGetField_R(LGBM_SE handle,
LGBM_SE field_data, LGBM_SE field_data,
LGBM_SE call_state); LGBM_SE call_state);
/*!
 * \brief Update parameters of a constructed Dataset
 * \param handle an instance of data matrix
 * \param params new parameters as a parameter string
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT LGBM_SE LGBM_DatasetUpdateParam_R(LGBM_SE handle,
LGBM_SE params,
LGBM_SE call_state);
/*! /*!
* \brief get number of data. * \brief get number of data.
* \param handle the handle to the dataset * \param handle the handle to the dataset
......
...@@ -1070,6 +1070,8 @@ class Dataset(object): ...@@ -1070,6 +1070,8 @@ class Dataset(object):
return self return self
def _update_params(self, params): def _update_params(self, params):
if self.handle is not None and params is not None:
_safe_call(_LIB.LGBM_DatasetUpdateParam(self.handle, c_str(param_dict_to_str(params))))
if not self.params: if not self.params:
self.params = params self.params = params
else: else:
...@@ -1080,6 +1082,8 @@ class Dataset(object): ...@@ -1080,6 +1082,8 @@ class Dataset(object):
def _reverse_update_params(self):
    """Restore the backed-up parameters.

    If the Dataset has already been constructed (``self.handle`` is set),
    the restored parameters are also pushed down to the native Dataset
    via ``LGBM_DatasetUpdateParam``.

    Returns
    -------
    self : Dataset
        Returns self.
    """
    self.params = copy.deepcopy(self.params_back_up)
    self.params_back_up = None
    if self.handle is not None and self.params is not None:
        _safe_call(_LIB.LGBM_DatasetUpdateParam(self.handle, c_str(param_dict_to_str(self.params))))
    return self
def set_field(self, field_name, data): def set_field(self, field_name, data):
......
...@@ -859,6 +859,13 @@ int LGBM_DatasetGetField(DatasetHandle handle, ...@@ -859,6 +859,13 @@ int LGBM_DatasetGetField(DatasetHandle handle,
API_END(); API_END();
} }
// C API entry point: forward the parameter string to Dataset::ResetConfig on
// the Dataset behind `handle`; errors are reported through API_BEGIN/API_END.
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->ResetConfig(parameters);
  API_END();
}
int LGBM_DatasetGetNumData(DatasetHandle handle, int LGBM_DatasetGetNumData(DatasetHandle handle,
int* out) { int* out) {
API_BEGIN(); API_BEGIN();
......
...@@ -317,6 +317,58 @@ void Dataset::Construct( ...@@ -317,6 +317,58 @@ void Dataset::Construct(
feature_penalty_.clear(); feature_penalty_.clear();
} }
} }
max_bin_ = io_config.max_bin;
min_data_in_bin_ = io_config.min_data_in_bin;
bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt;
use_missing_ = io_config.use_missing;
zero_as_missing_ = io_config.zero_as_missing;
}
// Apply a new parameter string to an already-constructed Dataset.
// Bin-related parameters captured at construction time (max_bin,
// bin_construct_sample_cnt, min_data_in_bin, use_missing, zero_as_missing)
// cannot change any more: if the caller passes a different value, we only
// emit a warning.  monotone_constraints and feature_contri, by contrast, are
// re-applied here.
void Dataset::ResetConfig(const char* parameters) {
  auto param = Config::Str2Map(parameters);
  Config io_config;
  io_config.Set(param);
  // Fixed-after-construction parameters: warn when the caller supplies a
  // value that differs from the one the Dataset was built with.
  if (param.count("max_bin") && io_config.max_bin != max_bin_) {
    Log::Warning("Cannot change max_bin after constructed Dataset handle.");
  }
  if (param.count("bin_construct_sample_cnt") && io_config.bin_construct_sample_cnt != bin_construct_sample_cnt_) {
    Log::Warning("Cannot change bin_construct_sample_cnt after constructed Dataset handle.");
  }
  if (param.count("min_data_in_bin") && io_config.min_data_in_bin != min_data_in_bin_) {
    Log::Warning("Cannot change min_data_in_bin after constructed Dataset handle.");
  }
  if (param.count("use_missing") && io_config.use_missing != use_missing_) {
    Log::Warning("Cannot change use_missing after constructed Dataset handle.");
  }
  if (param.count("zero_as_missing") && io_config.zero_as_missing != zero_as_missing_) {
    Log::Warning("Cannot change zero_as_missing after constructed Dataset handle.");
  }
  // Re-apply monotone constraints: the config vector is indexed by the
  // original (total) feature index, so map each entry onto the inner index of
  // the features actually kept in this Dataset.
  if (!io_config.monotone_constraints.empty()) {
    CHECK(static_cast<size_t>(num_total_features_) == io_config.monotone_constraints.size());
    monotone_types_.resize(num_features_);
    for (int i = 0; i < num_total_features_; ++i) {
      int inner_fidx = InnerFeatureIndex(i);
      if (inner_fidx >= 0) {
        monotone_types_[inner_fidx] = io_config.monotone_constraints[i];
      }
    }
    // An all-zero constraint vector means "no constraints": store it empty.
    if (ArrayArgs<int8_t>::CheckAllZero(monotone_types_)) {
      monotone_types_.clear();
    }
  }
  // Re-apply per-feature penalties (feature_contri), clamping each value to
  // be non-negative, with the same outer-to-inner index mapping as above.
  if (!io_config.feature_contri.empty()) {
    CHECK(static_cast<size_t>(num_total_features_) == io_config.feature_contri.size());
    feature_penalty_.resize(num_features_);
    for (int i = 0; i < num_total_features_; ++i) {
      int inner_fidx = InnerFeatureIndex(i);
      if (inner_fidx >= 0) {
        feature_penalty_[inner_fidx] = std::max(0.0, io_config.feature_contri[i]);
      }
    }
    // A penalty of exactly 1.0 for every feature is the default: store it empty.
    if (ArrayArgs<double>::CheckAll(feature_penalty_, 1.0)) {
      feature_penalty_.clear();
    }
  }
}
void Dataset::FinishLoad() { void Dataset::FinishLoad() {
...@@ -571,7 +623,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { ...@@ -571,7 +623,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_) size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_) + sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_)
+ 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_ + sizeof(int8_t) * num_features_ + 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_ + sizeof(int8_t) * num_features_
+ sizeof(double) * num_features_; + sizeof(double) * num_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
// size of feature names // size of feature names
for (int i = 0; i < num_total_features_; ++i) { for (int i = 0; i < num_total_features_; ++i) {
size_of_header += feature_names_[i].size() + sizeof(int); size_of_header += feature_names_[i].size() + sizeof(int);
...@@ -582,6 +634,11 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { ...@@ -582,6 +634,11 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
writer->Write(&num_features_, sizeof(num_features_)); writer->Write(&num_features_, sizeof(num_features_));
writer->Write(&num_total_features_, sizeof(num_total_features_)); writer->Write(&num_total_features_, sizeof(num_total_features_));
writer->Write(&label_idx_, sizeof(label_idx_)); writer->Write(&label_idx_, sizeof(label_idx_));
writer->Write(&max_bin_, sizeof(max_bin_));
writer->Write(&bin_construct_sample_cnt_, sizeof(bin_construct_sample_cnt_));
writer->Write(&min_data_in_bin_, sizeof(min_data_in_bin_));
writer->Write(&use_missing_, sizeof(use_missing_));
writer->Write(&zero_as_missing_, sizeof(zero_as_missing_));
writer->Write(used_feature_map_.data(), sizeof(int) * num_total_features_); writer->Write(used_feature_map_.data(), sizeof(int) * num_total_features_);
writer->Write(&num_groups_, sizeof(num_groups_)); writer->Write(&num_groups_, sizeof(num_groups_));
writer->Write(real_feature_idx_.data(), sizeof(int) * num_features_); writer->Write(real_feature_idx_.data(), sizeof(int) * num_features_);
......
...@@ -316,6 +316,16 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -316,6 +316,16 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
mem_ptr += sizeof(dataset->num_total_features_); mem_ptr += sizeof(dataset->num_total_features_);
dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr)); dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->label_idx_); mem_ptr += sizeof(dataset->label_idx_);
dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->max_bin_);
dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->bin_construct_sample_cnt_);
dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->min_data_in_bin_);
dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
mem_ptr += sizeof(dataset->use_missing_);
dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
mem_ptr += sizeof(dataset->zero_as_missing_);
const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr); const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
dataset->used_feature_map_.clear(); dataset->used_feature_map_.clear();
for (int i = 0; i < dataset->num_total_features_; ++i) { for (int i = 0; i < dataset->num_total_features_; ++i) {
......
...@@ -270,6 +270,14 @@ LGBM_SE LGBM_DatasetGetFieldSize_R(LGBM_SE handle, ...@@ -270,6 +270,14 @@ LGBM_SE LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
R_API_END(); R_API_END();
} }
// R API wrapper: unwrap the R handle and parameter string, then delegate to
// the C API's LGBM_DatasetUpdateParam; CHECK_CALL propagates any failure.
LGBM_SE LGBM_DatasetUpdateParam_R(LGBM_SE handle,
  LGBM_SE params,
  LGBM_SE call_state) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetUpdateParam(R_GET_PTR(handle), R_CHAR_PTR(params)));
  R_API_END();
}
LGBM_SE LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out, LGBM_SE LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out,
LGBM_SE call_state) { LGBM_SE call_state) {
int nrow; int nrow;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment