Unverified Commit dc699574 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Refine config object (#1381)

* [WIP] refine config

* [wip] ready for the auto code generate

* auto generate config codes

* use with to open file

* fix bug

* fix pylint

* fix bug

* fix pylint

* fix bugs.

* temporary commit for a failing test.

* fix tests.

* added nthreads alias

* added new aliases from new config.h

* fixed duplicated alias

* refactored parameter_generator.py

* added new aliases from config.h and removed remaining old names

* fix bugs & some missing aliases

* added aliases

* add more descriptions.

* add comment.
parent 497e60ed
...@@ -35,30 +35,30 @@ public: ...@@ -35,30 +35,30 @@ public:
Booster(const Dataset* train_data, Booster(const Dataset* train_data,
const char* parameters) { const char* parameters) {
CHECK(train_data->num_features() > 0); CHECK(train_data->num_features() > 0);
auto param = ConfigBase::Str2Map(parameters); auto param = Config::Str2Map(parameters);
config_.Set(param); config_.Set(param);
if (config_.num_threads > 0) { if (config_.num_threads > 0) {
omp_set_num_threads(config_.num_threads); omp_set_num_threads(config_.num_threads);
} }
// create boosting // create boosting
if (config_.io_config.input_model.size() > 0) { if (config_.input_model.size() > 0) {
Log::Warning("Continued train from model is not supported for c_api,\n" Log::Warning("Continued train from model is not supported for c_api,\n"
"please use continued train with input score"); "please use continued train with input score");
} }
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr)); boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
train_data_ = train_data; train_data_ = train_data;
CreateObjectiveAndMetrics(); CreateObjectiveAndMetrics();
// initialize the boosting // initialize the boosting
if (config_.boosting_config.tree_learner_type == std::string("feature")) { if (config_.tree_learner == std::string("feature")) {
Log::Fatal("Do not support feature parallel in c api"); Log::Fatal("Do not support feature parallel in c api");
} }
if (Network::num_machines() == 1 && config_.boosting_config.tree_learner_type != std::string("serial")) { if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
Log::Warning("Only find one worker, will switch to serial tree learner"); Log::Warning("Only find one worker, will switch to serial tree learner");
config_.boosting_config.tree_learner_type = "serial"; config_.tree_learner = "serial";
} }
boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(), boosting_->Init(&config_, train_data_, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
} }
...@@ -74,8 +74,8 @@ public: ...@@ -74,8 +74,8 @@ public:
void CreateObjectiveAndMetrics() { void CreateObjectiveAndMetrics() {
// create objective function // create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
config_.objective_config)); config_));
if (objective_fun_ == nullptr) { if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function"); Log::Warning("Using self-defined objective function");
} }
...@@ -86,9 +86,9 @@ public: ...@@ -86,9 +86,9 @@ public:
// create training metric // create training metric
train_metric_.clear(); train_metric_.clear();
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric) {
auto metric = std::unique_ptr<Metric>( auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config)); Metric::CreateMetric(metric_type, config_));
if (metric == nullptr) { continue; } if (metric == nullptr) { continue; }
metric->Init(train_data_->metadata(), train_data_->num_data()); metric->Init(train_data_->metadata(), train_data_->num_data());
train_metric_.push_back(std::move(metric)); train_metric_.push_back(std::move(metric));
...@@ -110,12 +110,12 @@ public: ...@@ -110,12 +110,12 @@ public:
void ResetConfig(const char* parameters) { void ResetConfig(const char* parameters) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto param = ConfigBase::Str2Map(parameters); auto param = Config::Str2Map(parameters);
if (param.count("num_class")) { if (param.count("num_class")) {
Log::Fatal("Cannot change num_class during training"); Log::Fatal("Cannot change num_class during training");
} }
if (param.count("boosting_type")) { if (param.count("boosting")) {
Log::Fatal("Cannot change boosting_type during training"); Log::Fatal("Cannot change boosting during training");
} }
if (param.count("metric")) { if (param.count("metric")) {
Log::Fatal("Cannot change metric during training"); Log::Fatal("Cannot change metric during training");
...@@ -128,8 +128,8 @@ public: ...@@ -128,8 +128,8 @@ public:
if (param.count("objective")) { if (param.count("objective")) {
// create objective function // create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
config_.objective_config)); config_));
if (objective_fun_ == nullptr) { if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function"); Log::Warning("Using self-defined objective function");
} }
...@@ -141,15 +141,15 @@ public: ...@@ -141,15 +141,15 @@ public:
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
} }
boosting_->ResetConfig(&config_.boosting_config); boosting_->ResetConfig(&config_);
} }
void AddValidData(const Dataset* valid_data) { void AddValidData(const Dataset* valid_data) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
valid_metrics_.emplace_back(); valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config)); auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
if (metric == nullptr) { continue; } if (metric == nullptr) { continue; }
metric->Init(valid_data->metadata(), valid_data->num_data()); metric->Init(valid_data->metadata(), valid_data->num_data());
valid_metrics_.back().push_back(std::move(metric)); valid_metrics_.back().push_back(std::move(metric));
...@@ -176,25 +176,25 @@ public: ...@@ -176,25 +176,25 @@ public:
void Predict(int num_iteration, int predict_type, int nrow, void Predict(int num_iteration, int predict_type, int nrow,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const IOConfig& config, const Config& config,
double* out_result, int64_t* out_len) { double* out_result, int64_t* out_len) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
bool is_predict_contrib = false; bool predict_contrib = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) { if (predict_type == C_API_PREDICT_LEAF_INDEX) {
is_predict_leaf = true; is_predict_leaf = true;
} else if (predict_type == C_API_PREDICT_RAW_SCORE) { } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
is_raw_score = true; is_raw_score = true;
} else if (predict_type == C_API_PREDICT_CONTRIB) { } else if (predict_type == C_API_PREDICT_CONTRIB) {
is_predict_contrib = true; predict_contrib = true;
} else { } else {
is_raw_score = false; is_raw_score = false;
} }
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib, Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin); config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, is_predict_contrib); int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
auto pred_fun = predictor.GetPredictFunction(); auto pred_fun = predictor.GetPredictFunction();
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
...@@ -210,22 +210,22 @@ public: ...@@ -210,22 +210,22 @@ public:
} }
void Predict(int num_iteration, int predict_type, const char* data_filename, void Predict(int num_iteration, int predict_type, const char* data_filename,
int data_has_header, const IOConfig& config, int data_has_header, const Config& config,
const char* result_filename) { const char* result_filename) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
bool is_predict_contrib = false; bool predict_contrib = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) { if (predict_type == C_API_PREDICT_LEAF_INDEX) {
is_predict_leaf = true; is_predict_leaf = true;
} else if (predict_type == C_API_PREDICT_RAW_SCORE) { } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
is_raw_score = true; is_raw_score = true;
} else if (predict_type == C_API_PREDICT_CONTRIB) { } else if (predict_type == C_API_PREDICT_CONTRIB) {
is_predict_contrib = true; predict_contrib = true;
} else { } else {
is_raw_score = false; is_raw_score = false;
} }
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib, Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin); config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
bool bool_data_has_header = data_has_header > 0 ? true : false; bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header); predictor.Predict(data_filename, result_filename, bool_data_has_header);
...@@ -300,7 +300,7 @@ private: ...@@ -300,7 +300,7 @@ private:
const Dataset* train_data_; const Dataset* train_data_;
std::unique_ptr<Boosting> boosting_; std::unique_ptr<Boosting> boosting_;
/*! \brief All configs */ /*! \brief All configs */
OverallConfig config_; Config config_;
/*! \brief Metric for training data */ /*! \brief Metric for training data */
std::vector<std::unique_ptr<Metric>> train_metric_; std::vector<std::unique_ptr<Metric>> train_metric_;
/*! \brief Metrics for validation data */ /*! \brief Metrics for validation data */
...@@ -356,13 +356,13 @@ int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -356,13 +356,13 @@ int LGBM_DatasetCreateFromFile(const char* filename,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = Config::Str2Map(parameters);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
} }
DatasetLoader loader(config.io_config,nullptr, 1, filename); DatasetLoader loader(config,nullptr, 1, filename);
if (reference == nullptr) { if (reference == nullptr) {
if (Network::num_machines() == 1) { if (Network::num_machines() == 1) {
*out = loader.LoadFromFile(filename, ""); *out = loader.LoadFromFile(filename, "");
...@@ -386,13 +386,13 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data, ...@@ -386,13 +386,13 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
const char* parameters, const char* parameters,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = Config::Str2Map(parameters);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
} }
DatasetLoader loader(config.io_config, nullptr, 1, nullptr); DatasetLoader loader(config, nullptr, 1, nullptr);
*out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col, *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
num_sample_row, num_sample_row,
static_cast<data_size_t>(num_total_row)); static_cast<data_size_t>(num_total_row));
...@@ -476,8 +476,8 @@ int LGBM_DatasetCreateFromMat(const void* data, ...@@ -476,8 +476,8 @@ int LGBM_DatasetCreateFromMat(const void* data,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = Config::Str2Map(parameters);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
...@@ -486,8 +486,8 @@ int LGBM_DatasetCreateFromMat(const void* data, ...@@ -486,8 +486,8 @@ int LGBM_DatasetCreateFromMat(const void* data,
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.io_config.data_random_seed); Random rand(config.data_random_seed);
int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size()); sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol); std::vector<std::vector<double>> sample_values(ncol);
...@@ -502,7 +502,7 @@ int LGBM_DatasetCreateFromMat(const void* data, ...@@ -502,7 +502,7 @@ int LGBM_DatasetCreateFromMat(const void* data,
} }
} }
} }
DatasetLoader loader(config.io_config, nullptr, 1, nullptr); DatasetLoader loader(config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(), Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()), static_cast<int>(sample_values.size()),
...@@ -540,8 +540,8 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -540,8 +540,8 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = Config::Str2Map(parameters);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
...@@ -551,8 +551,8 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -551,8 +551,8 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
int32_t nrow = static_cast<int32_t>(nindptr - 1); int32_t nrow = static_cast<int32_t>(nindptr - 1);
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.io_config.data_random_seed); Random rand(config.data_random_seed);
int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size()); sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(num_col); std::vector<std::vector<double>> sample_values(num_col);
...@@ -568,7 +568,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -568,7 +568,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
} }
} }
} }
DatasetLoader loader(config.io_config, nullptr, 1, nullptr); DatasetLoader loader(config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(), Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()), static_cast<int>(sample_values.size()),
...@@ -606,8 +606,8 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -606,8 +606,8 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = Config::Str2Map(parameters);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
...@@ -616,8 +616,8 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -616,8 +616,8 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int32_t nrow = static_cast<int32_t>(num_row); int32_t nrow = static_cast<int32_t>(num_row);
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.io_config.data_random_seed); Random rand(config.data_random_seed);
int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size()); sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol_ptr - 1); std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
...@@ -637,7 +637,7 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -637,7 +637,7 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
DatasetLoader loader(config.io_config, nullptr, 1, nullptr); DatasetLoader loader(config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(), Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()), static_cast<int>(sample_values.size()),
...@@ -681,8 +681,8 @@ int LGBM_DatasetGetSubset( ...@@ -681,8 +681,8 @@ int LGBM_DatasetGetSubset(
const char* parameters, const char* parameters,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = Config::Str2Map(parameters);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
...@@ -996,15 +996,15 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -996,15 +996,15 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* parameter, const char* parameter,
const char* result_filename) { const char* result_filename) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameter); auto param = Config::Str2Map(parameter);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header, ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
config.io_config, result_filename); config, result_filename);
API_END(); API_END();
} }
...@@ -1035,8 +1035,8 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -1035,8 +1035,8 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameter); auto param = Config::Str2Map(parameter);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
...@@ -1045,7 +1045,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -1045,7 +1045,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int nrow = static_cast<int>(nindptr - 1); int nrow = static_cast<int>(nindptr - 1);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
config.io_config, out_result, out_len); config, out_result, out_len);
API_END(); API_END();
} }
...@@ -1065,8 +1065,8 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1065,8 +1065,8 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto param = ConfigBase::Str2Map(parameter); auto param = Config::Str2Map(parameter);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
...@@ -1096,7 +1096,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1096,7 +1096,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
} }
return one_row; return one_row;
}; };
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config.io_config, ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config,
out_result, out_len); out_result, out_len);
API_END(); API_END();
} }
...@@ -1113,8 +1113,8 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1113,8 +1113,8 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameter); auto param = Config::Str2Map(parameter);
OverallConfig config; Config config;
config.Set(param); config.Set(param);
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
...@@ -1122,7 +1122,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1122,7 +1122,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
config.io_config, out_result, out_len); config, out_result, out_len);
API_END(); API_END();
} }
...@@ -1203,7 +1203,7 @@ int LGBM_NetworkInit(const char* machines, ...@@ -1203,7 +1203,7 @@ int LGBM_NetworkInit(const char* machines,
int listen_time_out, int listen_time_out,
int num_machines) { int num_machines) {
API_BEGIN(); API_BEGIN();
NetworkConfig config; Config config;
config.machines = Common::RemoveQuotationSymbol(std::string(machines)); config.machines = Common::RemoveQuotationSymbol(std::string(machines));
config.local_listen_port = local_listen_port; config.local_listen_port = local_listen_port;
config.num_machines = num_machines; config.num_machines = num_machines;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
namespace LightGBM { namespace LightGBM {
void ConfigBase::KV2Map(std::unordered_map<std::string, std::string>& params, const char* kv) { void Config::KV2Map(std::unordered_map<std::string, std::string>& params, const char* kv) {
std::vector<std::string> tmp_strs = Common::Split(kv, '='); std::vector<std::string> tmp_strs = Common::Split(kv, '=');
if (tmp_strs.size() == 2) { if (tmp_strs.size() == 2) {
std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0])); std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
...@@ -32,7 +32,7 @@ void ConfigBase::KV2Map(std::unordered_map<std::string, std::string>& params, co ...@@ -32,7 +32,7 @@ void ConfigBase::KV2Map(std::unordered_map<std::string, std::string>& params, co
} }
} }
std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* parameters) { std::unordered_map<std::string, std::string> Config::Str2Map(const char* parameters) {
std::unordered_map<std::string, std::string> params; std::unordered_map<std::string, std::string> params;
auto args = Common::Split(parameters, " \t\n\r"); auto args = Common::Split(parameters, " \t\n\r");
for (auto arg : args) { for (auto arg : args) {
...@@ -42,76 +42,76 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par ...@@ -42,76 +42,76 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
return params; return params;
} }
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting_type) { void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting) {
std::string value; std::string value;
if (ConfigBase::GetString(params, "boosting_type", &value)) { if (Config::GetString(params, "boosting", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("gbdt") || value == std::string("gbrt")) { if (value == std::string("gbdt") || value == std::string("gbrt")) {
*boosting_type = "gbdt"; *boosting = "gbdt";
} else if (value == std::string("dart")) { } else if (value == std::string("dart")) {
*boosting_type = "dart"; *boosting = "dart";
} else if (value == std::string("goss")) { } else if (value == std::string("goss")) {
*boosting_type = "goss"; *boosting = "goss";
} else if (value == std::string("rf") || value == std::string("randomforest")) { } else if (value == std::string("rf") || value == std::string("randomforest")) {
*boosting_type = "rf"; *boosting = "rf";
} else { } else {
Log::Fatal("Unknown boosting type %s", value.c_str()); Log::Fatal("Unknown boosting type %s", value.c_str());
} }
} }
} }
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective_type) { void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
std::string value; std::string value;
if (ConfigBase::GetString(params, "objective", &value)) { if (Config::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
*objective_type = value; *objective = value;
} }
} }
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric_types) { void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
std::string value; std::string value;
if (ConfigBase::GetString(params, "metric", &value)) { if (Config::GetString(params, "metric", &value)) {
// clear old metrics // clear old metrics
metric_types->clear(); metric->clear();
// to lower // to lower
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
// split // split
std::vector<std::string> metrics = Common::Split(value.c_str(), ','); std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
// remove duplicate // remove duplicate
std::unordered_set<std::string> metric_sets; std::unordered_set<std::string> metric_sets;
for (auto& metric : metrics) { for (auto& met : metrics) {
std::transform(metric.begin(), metric.end(), metric.begin(), Common::tolower); std::transform(met.begin(), met.end(), met.begin(), Common::tolower);
if (metric_sets.count(metric) <= 0) { if (metric_sets.count(met) <= 0) {
metric_sets.insert(metric); metric_sets.insert(met);
} }
} }
for (auto& metric : metric_sets) { for (auto& met : metric_sets) {
metric_types->push_back(metric); metric->push_back(met);
} }
metric_types->shrink_to_fit(); metric->shrink_to_fit();
} }
// add names of objective function if not providing metric // add names of objective function if not providing metric
if (metric_types->empty() && value.size() == 0) { if (metric->empty() && value.size() == 0) {
if (ConfigBase::GetString(params, "objective", &value)) { if (Config::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
metric_types->push_back(value); metric->push_back(value);
} }
} }
} }
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task_type) { void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
std::string value; std::string value;
if (ConfigBase::GetString(params, "task", &value)) { if (Config::GetString(params, "task", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("train") || value == std::string("training")) { if (value == std::string("train") || value == std::string("training")) {
*task_type = TaskType::kTrain; *task = TaskType::kTrain;
} else if (value == std::string("predict") || value == std::string("prediction") } else if (value == std::string("predict") || value == std::string("prediction")
|| value == std::string("test")) { || value == std::string("test")) {
*task_type = TaskType::kPredict; *task = TaskType::kPredict;
} else if (value == std::string("convert_model")) { } else if (value == std::string("convert_model")) {
*task_type = TaskType::kConvertModel; *task = TaskType::kConvertModel;
} else if (value == std::string("refit") || value == std::string("refit_tree")) { } else if (value == std::string("refit") || value == std::string("refit_tree")) {
*task_type = TaskType::KRefitTree; *task = TaskType::KRefitTree;
} else { } else {
Log::Fatal("Unknown task type %s", value.c_str()); Log::Fatal("Unknown task type %s", value.c_str());
} }
...@@ -120,7 +120,7 @@ void GetTaskType(const std::unordered_map<std::string, std::string>& params, Tas ...@@ -120,7 +120,7 @@ void GetTaskType(const std::unordered_map<std::string, std::string>& params, Tas
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) { void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
std::string value; std::string value;
if (ConfigBase::GetString(params, "device", &value)) { if (Config::GetString(params, "device", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("cpu")) { if (value == std::string("cpu")) {
*device_type = "cpu"; *device_type = "cpu";
...@@ -132,92 +132,89 @@ void GetDeviceType(const std::unordered_map<std::string, std::string>& params, s ...@@ -132,92 +132,89 @@ void GetDeviceType(const std::unordered_map<std::string, std::string>& params, s
} }
} }
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner_type) { void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner) {
std::string value; std::string value;
if (ConfigBase::GetString(params, "tree_learner", &value)) { if (Config::GetString(params, "tree_learner", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("serial")) { if (value == std::string("serial")) {
*tree_learner_type = "serial"; *tree_learner = "serial";
} else if (value == std::string("feature") || value == std::string("feature_parallel")) { } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
*tree_learner_type = "feature"; *tree_learner = "feature";
} else if (value == std::string("data") || value == std::string("data_parallel")) { } else if (value == std::string("data") || value == std::string("data_parallel")) {
*tree_learner_type = "data"; *tree_learner = "data";
} else if (value == std::string("voting") || value == std::string("voting_parallel")) { } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
*tree_learner_type = "voting"; *tree_learner = "voting";
} else { } else {
Log::Fatal("Unknown tree learner type %s", value.c_str()); Log::Fatal("Unknown tree learner type %s", value.c_str());
} }
} }
} }
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) { void Config::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types
GetInt(params, "num_threads", &num_threads);
GetString(params, "convert_model_language", &convert_model_language);
// generate seeds by seed. // generate seeds by seed.
if (GetInt(params, "seed", &seed)) { if (GetInt(params, "seed", &seed)) {
Random rand(seed); Random rand(seed);
int int_max = std::numeric_limits<short>::max(); int int_max = std::numeric_limits<short>::max();
io_config.data_random_seed = static_cast<int>(rand.NextShort(0, int_max)); data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.bagging_seed = static_cast<int>(rand.NextShort(0, int_max)); bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max)); drop_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max)); feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
} }
GetTaskType(params, &task_type);
GetBoostingType(params, &boosting_type);
GetMetricType(params, &metric_types); GetTaskType(params, &task);
GetBoostingType(params, &boosting);
GetMetricType(params, &metric);
GetObjectiveType(params, &objective);
GetDeviceType(params, &device_type);
GetTreeLearnerType(params, &tree_learner);
// sub-config setup GetMembersFromString(params);
network_config.Set(params);
io_config.Set(params);
boosting_config.Set(params); if (valid_data_initscores.size() == 0 && valid.size() > 0) {
GetObjectiveType(params, &objective_type); valid_data_initscores = std::vector<std::string>(valid.size(), "");
objective_config.Set(params); }
metric_config.Set(params);
// check for conflicts // check for conflicts
CheckParamConflict(); CheckParamConflict();
if (io_config.verbosity == 1) { if (verbosity == 1) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info); LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
} else if (io_config.verbosity == 0) { } else if (verbosity == 0) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning); LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
} else if (io_config.verbosity >= 2) { } else if (verbosity >= 2) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug); LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
} else { } else {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal); LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
} }
} }
// Returns true iff the objective string names a multiclass objective,
// including all accepted aliases of "multiclass" and "multiclassova".
// NOTE(review): reconstructed from diff-fused lines (post-refactor side kept).
bool CheckMultiClassObjective(const std::string& objective) {
  return (objective == std::string("multiclass")
          || objective == std::string("multiclassova")
          || objective == std::string("softmax")
          || objective == std::string("multiclass_ova")
          || objective == std::string("ova")
          || objective == std::string("ovr"));
}
void OverallConfig::CheckParamConflict() { void Config::CheckParamConflict() {
// check if objective_type, metric_type, and num_class match // check if objective, metric, and num_class match
int num_class_check = boosting_config.num_class; int num_class_check = num_class;
bool objective_custom = objective_type == std::string("none") || objective_type == std::string("null") || objective_type == std::string("custom"); bool objective_custom = objective == std::string("none") || objective == std::string("null") || objective == std::string("custom");
bool objective_type_multiclass = CheckMultiClassObjective(objective_type) || (objective_custom && num_class_check > 1); bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective_custom && num_class_check > 1);
if (objective_type_multiclass) { if (objective_type_multiclass) {
if (num_class_check <= 1) { if (num_class_check <= 1) {
Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training"); Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
} }
} else { } else {
if (task_type == TaskType::kTrain && num_class_check != 1) { if (task == TaskType::kTrain && num_class_check != 1) {
Log::Fatal("Number of classes must be 1 for non-multiclass training"); Log::Fatal("Number of classes must be 1 for non-multiclass training");
} }
} }
if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) { if (is_provide_training_metric || !valid.empty()) {
for (std::string metric_type : metric_types) { for (std::string metric_type : metric) {
bool metric_type_multiclass = (CheckMultiClassObjective(metric_type) bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
|| metric_type == std::string("multi_logloss") || metric_type == std::string("multi_logloss")
|| metric_type == std::string("multi_error")); || metric_type == std::string("multi_error"));
...@@ -228,256 +225,53 @@ void OverallConfig::CheckParamConflict() { ...@@ -228,256 +225,53 @@ void OverallConfig::CheckParamConflict() {
} }
} }
if (network_config.num_machines > 1) { if (num_machines > 1) {
is_parallel = true; is_parallel = true;
} else { } else {
is_parallel = false; is_parallel = false;
boosting_config.tree_learner_type = "serial"; tree_learner = "serial";
} }
bool is_single_tree_learner = boosting_config.tree_learner_type == std::string("serial"); bool is_single_tree_learner = tree_learner == std::string("serial");
if (is_single_tree_learner) { if (is_single_tree_learner) {
is_parallel = false; is_parallel = false;
network_config.num_machines = 1; num_machines = 1;
} }
if (is_single_tree_learner || boosting_config.tree_learner_type == std::string("feature")) { if (is_single_tree_learner || tree_learner == std::string("feature")) {
is_parallel_find_bin = false; is_parallel_find_bin = false;
} else if (boosting_config.tree_learner_type == std::string("data") } else if (tree_learner == std::string("data")
|| boosting_config.tree_learner_type == std::string("voting")) { || tree_learner == std::string("voting")) {
is_parallel_find_bin = true; is_parallel_find_bin = true;
if (boosting_config.tree_config.histogram_pool_size >= 0 if (histogram_pool_size >= 0
&& boosting_config.tree_learner_type == std::string("data")) { && tree_learner == std::string("data")) {
Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n" Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
"Will disable this to reduce communication costs", "Will disable this to reduce communication costs",
boosting_config.tree_config.histogram_pool_size); histogram_pool_size);
// Change pool size to -1 (no limit) when using data parallel to reduce communication costs // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
boosting_config.tree_config.histogram_pool_size = -1; histogram_pool_size = -1;
} }
} }
// Check max_depth and num_leaves // Check max_depth and num_leaves
if (boosting_config.tree_config.max_depth > 0) { if (max_depth > 0) {
int full_num_leaves = static_cast<int>(std::pow(2, boosting_config.tree_config.max_depth)); int full_num_leaves = static_cast<int>(std::pow(2, max_depth));
if (full_num_leaves > boosting_config.tree_config.num_leaves if (full_num_leaves > num_leaves
&& boosting_config.tree_config.num_leaves == kDefaultNumLeaves) { && num_leaves == kDefaultNumLeaves) {
Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves"); Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves");
} }
} }
} }
void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { std::string Config::ToString() const {
GetInt(params, "max_bin", &max_bin); std::stringstream str_buf;
CHECK(max_bin > 0); str_buf << "[boosting: " << boosting << "]\n";
GetInt(params, "num_class", &num_class); str_buf << "[objective: " << objective << "]\n";
CHECK(num_class > 0); str_buf << "[metric: " << Common::Join(metric, ",") << "]\n";
GetInt(params, "data_random_seed", &data_random_seed); str_buf << "[tree_learner: " << tree_learner << "]\n";
GetString(params, "data", &data_filename); str_buf << "[device_type: " << device_type << "]\n";
GetString(params, "init_score_file", &initscore_filename); str_buf << SaveMembersToString();
GetInt(params, "verbose", &verbosity); return str_buf.str();
GetInt(params, "num_iteration_predict", &num_iteration_predict);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
CHECK(bin_construct_sample_cnt > 0);
GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse);
GetDouble(params, "sparse_threshold", &sparse_threshold);
GetBool(params, "use_two_round_loading", &use_two_round_loading);
GetBool(params, "is_save_binary_file", &is_save_binary_file);
GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
GetBool(params, "is_predict_contrib", &is_predict_contrib);
GetInt(params, "snapshot_freq", &snapshot_freq);
GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model);
GetString(params, "convert_model", &convert_model);
GetString(params, "output_result", &output_result);
std::string tmp_str = "";
if (GetString(params, "monotone_constraints", &tmp_str)) {
monotone_constraints = Common::StringToArray<int8_t>(tmp_str.c_str(), ',');
}
if (GetString(params, "valid_data", &tmp_str)) {
valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
}
if (GetString(params, "valid_init_score_file", &tmp_str)) {
valid_data_initscores = Common::Split(tmp_str.c_str(), ',');
} else {
valid_data_initscores = std::vector<std::string>(valid_data_filenames.size(), "");
}
CHECK(valid_data_filenames.size() == valid_data_initscores.size());
GetBool(params, "has_header", &has_header);
GetString(params, "label_column", &label_column);
GetString(params, "weight_column", &weight_column);
GetString(params, "group_column", &group_column);
GetString(params, "ignore_column", &ignore_column);
GetString(params, "categorical_column", &categorical_column);
GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin > 0);
CHECK(min_data_in_leaf >= 0);
GetDouble(params, "max_conflict_rate", &max_conflict_rate);
CHECK(max_conflict_rate >= 0);
GetBool(params, "enable_bundle", &enable_bundle);
GetBool(params, "pred_early_stop", &pred_early_stop);
GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetBool(params, "use_missing", &use_missing);
GetBool(params, "zero_as_missing", &zero_as_missing);
GetDeviceType(params, &device_type);
}
// Legacy (pre-refactor) parser for objective-function parameters.
// Reads each objective-related key from the parsed parameter map and
// range-checks it; removed by this commit in favor of the single
// Config::GetMembersFromString.
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetBool(params, "is_unbalance", &is_unbalance);
  GetDouble(params, "sigmoid", &sigmoid);
  CHECK(sigmoid > 0);
  GetDouble(params, "fair_c", &fair_c);
  CHECK(fair_c > 0);
  GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
  CHECK(poisson_max_delta_step > 0);
  GetInt(params, "max_position", &max_position);
  CHECK(max_position > 0);
  GetInt(params, "num_class", &num_class);
  CHECK(num_class > 0);
  GetDouble(params, "scale_pos_weight", &scale_pos_weight);
  CHECK(scale_pos_weight > 0);
  GetDouble(params, "alpha", &alpha);
  GetBool(params, "reg_sqrt", &reg_sqrt);
  GetDouble(params, "tweedie_variance_power", &tweedie_variance_power);
  CHECK(tweedie_variance_power >= 1 && tweedie_variance_power < 2);
  std::string tmp_str = "";
  // Ranking label gains: either a user-supplied comma-separated list, or the
  // default 2^i - 1 progression for labels 0..30.
  if (GetString(params, "label_gain", &tmp_str)) {
    label_gain = Common::StringToArray<double>(tmp_str.c_str(), ',');
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
    label_gain.push_back(0.0f);
    for (int i = 1; i < max_label; ++i) {
      label_gain.push_back(static_cast<double>((1 << i) - 1));
    }
  }
  label_gain.shrink_to_fit();
}
// Legacy (pre-refactor) parser for metric parameters.
// Mirrors several objective parameters (sigmoid, fair_c, num_class, alpha,
// tweedie_variance_power, label_gain) because metrics evaluate the same
// quantities; removed by this commit in favor of Config::GetMembersFromString.
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetDouble(params, "sigmoid", &sigmoid);
  CHECK(sigmoid > 0);
  GetDouble(params, "fair_c", &fair_c);
  CHECK(fair_c > 0);
  GetInt(params, "num_class", &num_class);
  CHECK(num_class > 0);
  GetDouble(params, "alpha", &alpha);
  GetDouble(params, "tweedie_variance_power", &tweedie_variance_power);
  CHECK(tweedie_variance_power >= 1 && tweedie_variance_power < 2);
  std::string tmp_str = "";
  // Ranking label gains: user-supplied list, or the default 2^i - 1 progression.
  if (GetString(params, "label_gain", &tmp_str)) {
    label_gain = Common::StringToArray<double>(tmp_str, ',');
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
    label_gain.push_back(0.0f);
    for (int i = 1; i < max_label; ++i) {
      label_gain.push_back(static_cast<double>((1 << i) - 1));
    }
  }
  label_gain.shrink_to_fit();
  // NDCG evaluation positions: user-supplied (sorted, must be positive),
  // or the default 1..5.
  if (GetString(params, "ndcg_eval_at", &tmp_str)) {
    eval_at = Common::StringToArray<int>(tmp_str, ',');
    std::sort(eval_at.begin(), eval_at.end());
    for (size_t i = 0; i < eval_at.size(); ++i) {
      CHECK(eval_at[i] > 0);
    }
  } else {
    // default eval ndcg @[1-5]
    for (int i = 1; i <= 5; ++i) {
      eval_at.push_back(i);
    }
  }
  eval_at.shrink_to_fit();
}
// Legacy (pre-refactor) parser for tree-learner parameters (regularization,
// tree shape, feature sampling, GPU options, categorical-split options);
// removed by this commit in favor of Config::GetMembersFromString.
void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
  GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
  CHECK(min_data_in_leaf > 0);
  CHECK(min_sum_hessian_in_leaf >= 0);
  GetDouble(params, "lambda_l1", &lambda_l1);
  CHECK(lambda_l1 >= 0.0f);
  GetDouble(params, "lambda_l2", &lambda_l2);
  CHECK(lambda_l2 >= 0.0f);
  GetDouble(params, "max_delta_step", &max_delta_step);
  GetDouble(params, "min_gain_to_split", &min_gain_to_split);
  CHECK(min_gain_to_split >= 0.0f);
  GetInt(params, "num_leaves", &num_leaves);
  CHECK(num_leaves > 1);
  GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
  GetDouble(params, "feature_fraction", &feature_fraction);
  CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
  GetDouble(params, "histogram_pool_size", &histogram_pool_size);
  GetInt(params, "max_depth", &max_depth);
  // top_k is used by the voting-parallel tree learner.
  GetInt(params, "top_k", &top_k);
  CHECK(top_k > 0);
  GetInt(params, "gpu_platform_id", &gpu_platform_id);
  GetInt(params, "gpu_device_id", &gpu_device_id);
  GetBool(params, "gpu_use_dp", &gpu_use_dp);
  // Categorical-feature split controls.
  GetInt(params, "max_cat_threshold", &max_cat_threshold);
  GetDouble(params, "cat_l2", &cat_l2);
  GetDouble(params, "cat_smooth", &cat_smooth);
  GetInt(params, "min_data_per_group", &min_data_per_group);
  GetInt(params, "max_cat_to_onehot", &max_cat_to_onehot);
  CHECK(max_cat_threshold > 0);
  CHECK(cat_l2 >= 0.0f);
  CHECK(cat_smooth >= 1);
  CHECK(min_data_per_group > 0);
  CHECK(max_cat_to_onehot > 0);
}
// Legacy (pre-refactor) parser for boosting parameters (iterations, bagging,
// learning rate, DART/GOSS options), which then delegates to the nested
// tree config; removed by this commit in favor of Config::GetMembersFromString.
void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "num_iterations", &num_iterations);
  CHECK(num_iterations >= 0);
  GetInt(params, "bagging_seed", &bagging_seed);
  GetInt(params, "bagging_freq", &bagging_freq);
  GetDouble(params, "bagging_fraction", &bagging_fraction);
  CHECK(bagging_fraction > 0.0f && bagging_fraction <= 1.0f);
  GetDouble(params, "learning_rate", &learning_rate);
  CHECK(learning_rate > 0.0f);
  GetInt(params, "early_stopping_round", &early_stopping_round);
  CHECK(early_stopping_round >= 0);
  GetInt(params, "output_freq", &output_freq);
  CHECK(output_freq >= 0);
  GetBool(params, "is_training_metric", &is_provide_training_metric);
  GetInt(params, "num_class", &num_class);
  CHECK(num_class > 0);
  // DART-specific parameters.
  GetInt(params, "drop_seed", &drop_seed);
  GetDouble(params, "drop_rate", &drop_rate);
  GetDouble(params, "skip_drop", &skip_drop);
  CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
  CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
  GetInt(params, "max_drop", &max_drop);
  GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
  GetBool(params, "uniform_drop", &uniform_drop);
  // GOSS-specific parameters: the two sampling rates must sum to at most 1.
  GetDouble(params, "top_rate", &top_rate);
  GetDouble(params, "other_rate", &other_rate);
  CHECK(top_rate > 0);
  CHECK(other_rate > 0);
  CHECK(top_rate + other_rate <= 1.0);
  GetBool(params, "boost_from_average", &boost_from_average);
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner_type);
  GetString(params, "forced_splits", &forcedsplits_filename);
  // Nested tree-learner configuration is parsed from the same parameter map.
  tree_config.Set(params);
}
// Legacy (pre-refactor) parser for distributed-training network parameters;
// removed by this commit in favor of Config::GetMembersFromString.
void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "num_machines", &num_machines);
  CHECK(num_machines >= 1);
  GetInt(params, "local_listen_port", &local_listen_port);
  CHECK(local_listen_port > 0);
  GetInt(params, "time_out", &time_out);
  CHECK(time_out > 0);
  GetString(params, "machine_list_file", &machine_list_filename);
  GetString(params, "machines", &machines);
}
} // namespace LightGBM } // namespace LightGBM
/// This file is auto generated by LightGBM\helper\parameter_generator.py
#include<LightGBM/config.h>
namespace LightGBM {
// Maps every accepted parameter alias to its canonical parameter name.
// Auto-generated from config.h by helper/parameter_generator.py — edit the
// descriptions in config.h and regenerate rather than editing this table.
std::unordered_map<std::string, std::string> Config::alias_table({
  {"config_file", "config"},
  {"task_type", "task"},
  {"application", "objective"},
  {"app", "objective"},
  {"objective_type", "objective"},
  {"boosting_type", "boosting"},
  {"boost", "boosting"},
  {"train", "data"},
  {"train_data", "data"},
  {"data_filename", "data"},
  {"test", "valid"},
  {"valid_data", "valid"},
  {"test_data", "valid"},
  {"valid_filenames", "valid"},
  {"num_iteration", "num_iterations"},
  {"num_tree", "num_iterations"},
  {"num_trees", "num_iterations"},
  {"num_round", "num_iterations"},
  {"num_rounds", "num_iterations"},
  {"num_boost_round", "num_iterations"},
  {"n_estimators", "num_iterations"},
  {"shrinkage_rate", "learning_rate"},
  {"num_leaf", "num_leaves"},
  {"tree", "tree_learner"},
  {"tree_learner_type", "tree_learner"},
  {"num_thread", "num_threads"},
  {"nthread", "num_threads"},
  {"nthreads", "num_threads"},
  {"random_seed", "seed"},
  {"min_data_per_leaf", "min_data_in_leaf"},
  {"min_data", "min_data_in_leaf"},
  {"min_child_samples", "min_data_in_leaf"},
  {"min_sum_hessian_per_leaf", "min_sum_hessian_in_leaf"},
  {"min_sum_hessian", "min_sum_hessian_in_leaf"},
  {"min_hessian", "min_sum_hessian_in_leaf"},
  {"min_child_weight", "min_sum_hessian_in_leaf"},
  {"sub_row", "bagging_fraction"},
  {"subsample", "bagging_fraction"},
  {"bagging", "bagging_fraction"},
  {"subsample_freq", "bagging_freq"},
  {"bagging_fraction_seed", "bagging_seed"},
  {"sub_feature", "feature_fraction"},
  {"colsample_bytree", "feature_fraction"},
  {"early_stopping_rounds", "early_stopping_round"},
  {"early_stopping", "early_stopping_round"},
  {"max_tree_output", "max_delta_step"},
  {"max_leaf_output", "max_delta_step"},
  {"reg_alpha", "lambda_l1"},
  {"reg_lambda", "lambda_l2"},
  {"min_split_gain", "min_gain_to_split"},
  {"topk", "top_k"},
  {"mc", "monotone_constraints"},
  {"monotone_constraint", "monotone_constraints"},
  {"forced_splits_filename", "forcedsplits_filename"},
  {"forced_splits_file", "forcedsplits_filename"},
  {"forced_splits", "forcedsplits_filename"},
  {"model_output", "output_model"},
  {"model_out", "output_model"},
  {"model_input", "input_model"},
  {"model_in", "input_model"},
  {"predict_result", "output_result"},
  {"prediction_result", "output_result"},
  {"is_pre_partition", "pre_partition"},
  {"is_sparse", "is_enable_sparse"},
  {"enable_sparse", "is_enable_sparse"},
  {"two_round_loading", "two_round"},
  {"use_two_round_loading", "two_round"},
  {"is_save_binary", "save_binary"},
  {"is_save_binary_file", "save_binary"},
  {"verbose", "verbosity"},
  {"has_header", "header"},
  {"label", "label_column"},
  {"weight", "weight_column"},
  {"query_column", "group_column"},
  {"group", "group_column"},
  {"query", "group_column"},
  {"ignore_feature", "ignore_column"},
  {"blacklist", "ignore_column"},
  {"categorical_column", "categorical_feature"},
  {"cat_feature", "categorical_feature"},
  {"cat_column", "categorical_feature"},
  {"raw_score", "predict_raw_score"},
  {"is_predict_raw_score", "predict_raw_score"},
  {"predict_rawscore", "predict_raw_score"},
  {"leaf_index", "predict_leaf_index"},
  {"is_predict_leaf_index", "predict_leaf_index"},
  {"contrib", "predict_contrib"},
  {"is_predict_contrib", "predict_contrib"},
  {"subsample_for_bin", "bin_construct_sample_cnt"},
  {"init_score_filename", "initscore_filename"},
  {"init_score_file", "initscore_filename"},
  {"init_score", "initscore_filename"},
  {"valid_data_init_scores", "valid_data_initscores"},
  {"valid_init_score_file", "valid_data_initscores"},
  {"valid_init_score", "valid_data_initscores"},
  {"num_classes", "num_class"},
  {"unbalanced_sets", "is_unbalance"},
  {"metric_types", "metric"},
  {"output_freq", "metric_freq"},
  {"training_metric", "is_provide_training_metric"},
  {"is_training_metric", "is_provide_training_metric"},
  {"train_metric", "is_provide_training_metric"},
  {"ndcg_eval_at", "eval_at"},
  {"ndcg_at", "eval_at"},
  {"num_machine", "num_machines"},
  {"local_port", "local_listen_port"},
  {"mlist", "machine_list_filename"},
  // NOTE(review): "works" looks like a typo for "workers" — confirm against
  // config.h and the generator before changing, since this file is generated.
  {"works", "machines"},
  {"nodes", "machines"},
});
// The set of all canonical parameter names; used to reject unknown keys after
// alias resolution. Auto-generated from config.h by
// helper/parameter_generator.py — regenerate rather than editing by hand.
std::unordered_set<std::string> Config::parameter_set({
  "config",
  "task",
  "objective",
  "boosting",
  "data",
  "valid",
  "num_iterations",
  "learning_rate",
  "num_leaves",
  "tree_learner",
  "num_threads",
  "device_type",
  "seed",
  "max_depth",
  "min_data_in_leaf",
  "min_sum_hessian_in_leaf",
  "bagging_fraction",
  "bagging_freq",
  "bagging_seed",
  "feature_fraction",
  "feature_fraction_seed",
  "early_stopping_round",
  "max_delta_step",
  "lambda_l1",
  "lambda_l2",
  "min_gain_to_split",
  "drop_rate",
  "max_drop",
  "skip_drop",
  "xgboost_dart_mode",
  "uniform_drop",
  "drop_seed",
  "top_rate",
  "other_rate",
  "min_data_per_group",
  "max_cat_threshold",
  "cat_l2",
  "cat_smooth",
  "max_cat_to_onehot",
  "top_k",
  "monotone_constraints",
  "forcedsplits_filename",
  "max_bin",
  "min_data_in_bin",
  "data_random_seed",
  "output_model",
  "input_model",
  "output_result",
  "pre_partition",
  "is_enable_sparse",
  "sparse_threshold",
  "two_round",
  "save_binary",
  "verbosity",
  "header",
  "label_column",
  "weight_column",
  "group_column",
  "ignore_column",
  "categorical_feature",
  "predict_raw_score",
  "predict_leaf_index",
  "predict_contrib",
  "num_iteration_predict",
  "pred_early_stop",
  "pred_early_stop_freq",
  "pred_early_stop_margin",
  "bin_construct_sample_cnt",
  "use_missing",
  "zero_as_missing",
  "initscore_filename",
  "valid_data_initscores",
  "histogram_pool_size",
  "enable_load_from_binary_file",
  "enable_bundle",
  "max_conflict_rate",
  "snapshot_freq",
  "convert_model_language",
  "convert_model",
  "num_class",
  "sigmoid",
  "alpha",
  "fair_c",
  "poisson_max_delta_step",
  "boost_from_average",
  "is_unbalance",
  "scale_pos_weight",
  "reg_sqrt",
  "tweedie_variance_power",
  "label_gain",
  "max_position",
  "metric",
  "metric_freq",
  "is_provide_training_metric",
  "eval_at",
  "num_machines",
  "local_listen_port",
  "time_out",
  "machine_list_filename",
  "machines",
  "gpu_platform_id",
  "gpu_device_id",
  "gpu_use_dp",
});
// Reads every plain (non-"type") config member from the parsed parameter map
// and range-checks it; members absent from the map keep their defaults.
// Auto-generated from config.h by helper/parameter_generator.py — regenerate
// rather than editing by hand.
void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
  std::string tmp_str = "";
  // Core training parameters.
  GetString(params, "data", &data);
  if (GetString(params, "valid", &tmp_str)) {
    valid = Common::Split(tmp_str.c_str(), ',');
  }
  GetInt(params, "num_iterations", &num_iterations);
  CHECK(num_iterations >=0);
  GetDouble(params, "learning_rate", &learning_rate);
  CHECK(learning_rate >0);
  GetInt(params, "num_leaves", &num_leaves);
  CHECK(num_leaves >1);
  GetInt(params, "num_threads", &num_threads);
  // Learning-control parameters.
  GetInt(params, "max_depth", &max_depth);
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
  CHECK(min_data_in_leaf >=0);
  GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
  GetDouble(params, "bagging_fraction", &bagging_fraction);
  CHECK(bagging_fraction >0);
  CHECK(bagging_fraction <=1.0);
  GetInt(params, "bagging_freq", &bagging_freq);
  GetInt(params, "bagging_seed", &bagging_seed);
  GetDouble(params, "feature_fraction", &feature_fraction);
  CHECK(feature_fraction >0);
  CHECK(feature_fraction <=1.0);
  GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
  GetInt(params, "early_stopping_round", &early_stopping_round);
  GetDouble(params, "max_delta_step", &max_delta_step);
  GetDouble(params, "lambda_l1", &lambda_l1);
  CHECK(lambda_l1 >=0);
  GetDouble(params, "lambda_l2", &lambda_l2);
  CHECK(lambda_l2 >=0);
  GetDouble(params, "min_gain_to_split", &min_gain_to_split);
  // DART parameters.
  GetDouble(params, "drop_rate", &drop_rate);
  CHECK(drop_rate >=0);
  CHECK(drop_rate <=1.0);
  GetInt(params, "max_drop", &max_drop);
  GetDouble(params, "skip_drop", &skip_drop);
  CHECK(skip_drop >=0);
  CHECK(skip_drop <=1.0);
  GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
  GetBool(params, "uniform_drop", &uniform_drop);
  GetInt(params, "drop_seed", &drop_seed);
  // GOSS parameters.
  GetDouble(params, "top_rate", &top_rate);
  CHECK(top_rate >=0);
  CHECK(top_rate <=1.0);
  GetDouble(params, "other_rate", &other_rate);
  CHECK(other_rate >=0);
  CHECK(other_rate <=1.0);
  // Categorical-feature split parameters.
  GetInt(params, "min_data_per_group", &min_data_per_group);
  CHECK(min_data_per_group >0);
  GetInt(params, "max_cat_threshold", &max_cat_threshold);
  CHECK(max_cat_threshold >0);
  GetDouble(params, "cat_l2", &cat_l2);
  CHECK(cat_l2 >=0);
  GetDouble(params, "cat_smooth", &cat_smooth);
  CHECK(cat_smooth >=0);
  GetInt(params, "max_cat_to_onehot", &max_cat_to_onehot);
  CHECK(max_cat_to_onehot >0);
  GetInt(params, "top_k", &top_k);
  if (GetString(params, "monotone_constraints", &tmp_str)) {
    monotone_constraints = Common::StringToArray<int8_t>(tmp_str, ',');
  }
  GetString(params, "forcedsplits_filename", &forcedsplits_filename);
  // IO parameters.
  GetInt(params, "max_bin", &max_bin);
  CHECK(max_bin >1);
  GetInt(params, "min_data_in_bin", &min_data_in_bin);
  CHECK(min_data_in_bin >0);
  GetInt(params, "data_random_seed", &data_random_seed);
  GetString(params, "output_model", &output_model);
  GetString(params, "input_model", &input_model);
  GetString(params, "output_result", &output_result);
  GetBool(params, "pre_partition", &pre_partition);
  GetBool(params, "is_enable_sparse", &is_enable_sparse);
  GetDouble(params, "sparse_threshold", &sparse_threshold);
  CHECK(sparse_threshold >0);
  CHECK(sparse_threshold <=1);
  GetBool(params, "two_round", &two_round);
  GetBool(params, "save_binary", &save_binary);
  GetInt(params, "verbosity", &verbosity);
  GetBool(params, "header", &header);
  GetString(params, "label_column", &label_column);
  GetString(params, "weight_column", &weight_column);
  GetString(params, "group_column", &group_column);
  GetString(params, "ignore_column", &ignore_column);
  GetString(params, "categorical_feature", &categorical_feature);
  // Prediction parameters.
  GetBool(params, "predict_raw_score", &predict_raw_score);
  GetBool(params, "predict_leaf_index", &predict_leaf_index);
  GetBool(params, "predict_contrib", &predict_contrib);
  GetInt(params, "num_iteration_predict", &num_iteration_predict);
  GetBool(params, "pred_early_stop", &pred_early_stop);
  GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
  GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
  GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
  CHECK(bin_construct_sample_cnt >0);
  GetBool(params, "use_missing", &use_missing);
  GetBool(params, "zero_as_missing", &zero_as_missing);
  GetString(params, "initscore_filename", &initscore_filename);
  if (GetString(params, "valid_data_initscores", &tmp_str)) {
    valid_data_initscores = Common::Split(tmp_str.c_str(), ',');
  }
  GetDouble(params, "histogram_pool_size", &histogram_pool_size);
  GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
  GetBool(params, "enable_bundle", &enable_bundle);
  GetDouble(params, "max_conflict_rate", &max_conflict_rate);
  CHECK(max_conflict_rate >=0);
  CHECK(max_conflict_rate <1);
  GetInt(params, "snapshot_freq", &snapshot_freq);
  GetString(params, "convert_model_language", &convert_model_language);
  GetString(params, "convert_model", &convert_model);
  // Objective parameters.
  GetInt(params, "num_class", &num_class);
  GetDouble(params, "sigmoid", &sigmoid);
  CHECK(sigmoid >0);
  GetDouble(params, "alpha", &alpha);
  GetDouble(params, "fair_c", &fair_c);
  GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
  GetBool(params, "boost_from_average", &boost_from_average);
  GetBool(params, "is_unbalance", &is_unbalance);
  GetDouble(params, "scale_pos_weight", &scale_pos_weight);
  CHECK(scale_pos_weight >0);
  GetBool(params, "reg_sqrt", &reg_sqrt);
  GetDouble(params, "tweedie_variance_power", &tweedie_variance_power);
  if (GetString(params, "label_gain", &tmp_str)) {
    label_gain = Common::StringToArray<double>(tmp_str, ',');
  }
  GetInt(params, "max_position", &max_position);
  CHECK(max_position >0);
  // Metric parameters.
  GetInt(params, "metric_freq", &metric_freq);
  CHECK(metric_freq >0);
  GetBool(params, "is_provide_training_metric", &is_provide_training_metric);
  if (GetString(params, "eval_at", &tmp_str)) {
    eval_at = Common::StringToArray<int>(tmp_str, ',');
  }
  // Network parameters.
  GetInt(params, "num_machines", &num_machines);
  GetInt(params, "local_listen_port", &local_listen_port);
  GetInt(params, "time_out", &time_out);
  GetString(params, "machine_list_filename", &machine_list_filename);
  GetString(params, "machines", &machines);
  // GPU parameters.
  GetInt(params, "gpu_platform_id", &gpu_platform_id);
  GetInt(params, "gpu_device_id", &gpu_device_id);
  GetBool(params, "gpu_use_dp", &gpu_use_dp);
}
std::string Config::SaveMembersToString() const {
  // Serializes every configuration parameter as one "[name: value]\n" line,
  // so the full config can be embedded in text form (e.g. inside a saved
  // model file) and later parsed back.
  // NOTE(review): per the commit this code is auto-generated from the
  // parameter definitions — when adding/removing parameters, update the
  // generator rather than hand-editing this function, or the two will drift.
  std::stringstream str_buf;
  str_buf << "[data: " << data << "]\n";
  // List-valued parameters are rendered as comma-separated values via Common::Join.
  str_buf << "[valid: " << Common::Join(valid,",") << "]\n";
  str_buf << "[num_iterations: " << num_iterations << "]\n";
  str_buf << "[learning_rate: " << learning_rate << "]\n";
  str_buf << "[num_leaves: " << num_leaves << "]\n";
  str_buf << "[num_threads: " << num_threads << "]\n";
  str_buf << "[max_depth: " << max_depth << "]\n";
  str_buf << "[min_data_in_leaf: " << min_data_in_leaf << "]\n";
  str_buf << "[min_sum_hessian_in_leaf: " << min_sum_hessian_in_leaf << "]\n";
  str_buf << "[bagging_fraction: " << bagging_fraction << "]\n";
  str_buf << "[bagging_freq: " << bagging_freq << "]\n";
  str_buf << "[bagging_seed: " << bagging_seed << "]\n";
  str_buf << "[feature_fraction: " << feature_fraction << "]\n";
  str_buf << "[feature_fraction_seed: " << feature_fraction_seed << "]\n";
  str_buf << "[early_stopping_round: " << early_stopping_round << "]\n";
  str_buf << "[max_delta_step: " << max_delta_step << "]\n";
  str_buf << "[lambda_l1: " << lambda_l1 << "]\n";
  str_buf << "[lambda_l2: " << lambda_l2 << "]\n";
  str_buf << "[min_gain_to_split: " << min_gain_to_split << "]\n";
  str_buf << "[drop_rate: " << drop_rate << "]\n";
  str_buf << "[max_drop: " << max_drop << "]\n";
  str_buf << "[skip_drop: " << skip_drop << "]\n";
  // bool members stream as 0/1 (std::boolalpha is not set on this stream).
  str_buf << "[xgboost_dart_mode: " << xgboost_dart_mode << "]\n";
  str_buf << "[uniform_drop: " << uniform_drop << "]\n";
  str_buf << "[drop_seed: " << drop_seed << "]\n";
  str_buf << "[top_rate: " << top_rate << "]\n";
  str_buf << "[other_rate: " << other_rate << "]\n";
  str_buf << "[min_data_per_group: " << min_data_per_group << "]\n";
  str_buf << "[max_cat_threshold: " << max_cat_threshold << "]\n";
  str_buf << "[cat_l2: " << cat_l2 << "]\n";
  str_buf << "[cat_smooth: " << cat_smooth << "]\n";
  str_buf << "[max_cat_to_onehot: " << max_cat_to_onehot << "]\n";
  str_buf << "[top_k: " << top_k << "]\n";
  // int8_t would stream as a raw character; cast the elements to int so the
  // constraint flags print as numbers (-1/0/1).
  str_buf << "[monotone_constraints: " << Common::Join(Common::ArrayCast<int8_t, int>(monotone_constraints),",") << "]\n";
  str_buf << "[forcedsplits_filename: " << forcedsplits_filename << "]\n";
  str_buf << "[max_bin: " << max_bin << "]\n";
  str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
  str_buf << "[data_random_seed: " << data_random_seed << "]\n";
  str_buf << "[output_model: " << output_model << "]\n";
  str_buf << "[input_model: " << input_model << "]\n";
  str_buf << "[output_result: " << output_result << "]\n";
  str_buf << "[pre_partition: " << pre_partition << "]\n";
  str_buf << "[is_enable_sparse: " << is_enable_sparse << "]\n";
  str_buf << "[sparse_threshold: " << sparse_threshold << "]\n";
  str_buf << "[two_round: " << two_round << "]\n";
  str_buf << "[save_binary: " << save_binary << "]\n";
  str_buf << "[verbosity: " << verbosity << "]\n";
  str_buf << "[header: " << header << "]\n";
  str_buf << "[label_column: " << label_column << "]\n";
  str_buf << "[weight_column: " << weight_column << "]\n";
  str_buf << "[group_column: " << group_column << "]\n";
  str_buf << "[ignore_column: " << ignore_column << "]\n";
  str_buf << "[categorical_feature: " << categorical_feature << "]\n";
  str_buf << "[predict_raw_score: " << predict_raw_score << "]\n";
  str_buf << "[predict_leaf_index: " << predict_leaf_index << "]\n";
  str_buf << "[predict_contrib: " << predict_contrib << "]\n";
  str_buf << "[num_iteration_predict: " << num_iteration_predict << "]\n";
  str_buf << "[pred_early_stop: " << pred_early_stop << "]\n";
  str_buf << "[pred_early_stop_freq: " << pred_early_stop_freq << "]\n";
  str_buf << "[pred_early_stop_margin: " << pred_early_stop_margin << "]\n";
  str_buf << "[bin_construct_sample_cnt: " << bin_construct_sample_cnt << "]\n";
  str_buf << "[use_missing: " << use_missing << "]\n";
  str_buf << "[zero_as_missing: " << zero_as_missing << "]\n";
  str_buf << "[initscore_filename: " << initscore_filename << "]\n";
  str_buf << "[valid_data_initscores: " << Common::Join(valid_data_initscores,",") << "]\n";
  str_buf << "[histogram_pool_size: " << histogram_pool_size << "]\n";
  str_buf << "[enable_load_from_binary_file: " << enable_load_from_binary_file << "]\n";
  str_buf << "[enable_bundle: " << enable_bundle << "]\n";
  str_buf << "[max_conflict_rate: " << max_conflict_rate << "]\n";
  str_buf << "[snapshot_freq: " << snapshot_freq << "]\n";
  str_buf << "[convert_model_language: " << convert_model_language << "]\n";
  str_buf << "[convert_model: " << convert_model << "]\n";
  str_buf << "[num_class: " << num_class << "]\n";
  str_buf << "[sigmoid: " << sigmoid << "]\n";
  str_buf << "[alpha: " << alpha << "]\n";
  str_buf << "[fair_c: " << fair_c << "]\n";
  str_buf << "[poisson_max_delta_step: " << poisson_max_delta_step << "]\n";
  str_buf << "[boost_from_average: " << boost_from_average << "]\n";
  str_buf << "[is_unbalance: " << is_unbalance << "]\n";
  str_buf << "[scale_pos_weight: " << scale_pos_weight << "]\n";
  str_buf << "[reg_sqrt: " << reg_sqrt << "]\n";
  str_buf << "[tweedie_variance_power: " << tweedie_variance_power << "]\n";
  str_buf << "[label_gain: " << Common::Join(label_gain,",") << "]\n";
  str_buf << "[max_position: " << max_position << "]\n";
  str_buf << "[metric_freq: " << metric_freq << "]\n";
  str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
  str_buf << "[eval_at: " << Common::Join(eval_at,",") << "]\n";
  str_buf << "[num_machines: " << num_machines << "]\n";
  str_buf << "[local_listen_port: " << local_listen_port << "]\n";
  str_buf << "[time_out: " << time_out << "]\n";
  str_buf << "[machine_list_filename: " << machine_list_filename << "]\n";
  str_buf << "[machines: " << machines << "]\n";
  str_buf << "[gpu_platform_id: " << gpu_platform_id << "]\n";
  str_buf << "[gpu_device_id: " << gpu_device_id << "]\n";
  str_buf << "[gpu_use_dp: " << gpu_use_dp << "]\n";
  // The whole buffer is returned as one string; the "[name: value]" line
  // format is what the corresponding string-parsing loader expects.
  return str_buf.str();
}
}
...@@ -214,7 +214,7 @@ void Dataset::Construct( ...@@ -214,7 +214,7 @@ void Dataset::Construct(
int** sample_non_zero_indices, int** sample_non_zero_indices,
const int* num_per_col, const int* num_per_col,
size_t total_sample_cnt, size_t total_sample_cnt,
const IOConfig& io_config) { const Config& io_config) {
num_total_features_ = static_cast<int>(bin_mappers.size()); num_total_features_ = static_cast<int>(bin_mappers.size());
sparse_threshold_ = io_config.sparse_threshold; sparse_threshold_ = io_config.sparse_threshold;
......
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
namespace LightGBM { namespace LightGBM {
DatasetLoader::DatasetLoader(const IOConfig& io_config, const PredictFunction& predict_fun, int num_class, const char* filename) DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
:io_config_(io_config), random_(io_config_.data_random_seed), predict_fun_(predict_fun), num_class_(num_class) { :config_(io_config), random_(config_.data_random_seed), predict_fun_(predict_fun), num_class_(num_class) {
label_idx_ = 0; label_idx_ = 0;
weight_idx_ = NO_SPECIFIC; weight_idx_ = NO_SPECIFIC;
group_idx_ = NO_SPECIFIC; group_idx_ = NO_SPECIFIC;
...@@ -24,18 +24,18 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -24,18 +24,18 @@ void DatasetLoader::SetHeader(const char* filename) {
std::unordered_map<std::string, int> name2idx; std::unordered_map<std::string, int> name2idx;
std::string name_prefix("name:"); std::string name_prefix("name:");
if (filename != nullptr) { if (filename != nullptr) {
TextReader<data_size_t> text_reader(filename, io_config_.has_header); TextReader<data_size_t> text_reader(filename, config_.header);
// get column names // get column names
if (io_config_.has_header) { if (config_.header) {
std::string first_line = text_reader.first_line(); std::string first_line = text_reader.first_line();
feature_names_ = Common::Split(first_line.c_str(), "\t,"); feature_names_ = Common::Split(first_line.c_str(), "\t,");
} }
// load label idx first // load label idx first
if (io_config_.label_column.size() > 0) { if (config_.label_column.size() > 0) {
if (Common::StartsWith(io_config_.label_column, name_prefix)) { if (Common::StartsWith(config_.label_column, name_prefix)) {
std::string name = io_config_.label_column.substr(name_prefix.size()); std::string name = config_.label_column.substr(name_prefix.size());
label_idx_ = -1; label_idx_ = -1;
for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) { for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
if (name == feature_names_[i]) { if (name == feature_names_[i]) {
...@@ -50,7 +50,7 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -50,7 +50,7 @@ void DatasetLoader::SetHeader(const char* filename) {
"or data file doesn't contain header", name.c_str()); "or data file doesn't contain header", name.c_str());
} }
} else { } else {
if (!Common::AtoiAndCheck(io_config_.label_column.c_str(), &label_idx_)) { if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
Log::Fatal("label_column is not a number,\n" Log::Fatal("label_column is not a number,\n"
"if you want to use a column name,\n" "if you want to use a column name,\n"
"please add the prefix \"name:\" to the column name"); "please add the prefix \"name:\" to the column name");
...@@ -68,9 +68,9 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -68,9 +68,9 @@ void DatasetLoader::SetHeader(const char* filename) {
} }
// load ignore columns // load ignore columns
if (io_config_.ignore_column.size() > 0) { if (config_.ignore_column.size() > 0) {
if (Common::StartsWith(io_config_.ignore_column, name_prefix)) { if (Common::StartsWith(config_.ignore_column, name_prefix)) {
std::string names = io_config_.ignore_column.substr(name_prefix.size()); std::string names = config_.ignore_column.substr(name_prefix.size());
for (auto name : Common::Split(names.c_str(), ',')) { for (auto name : Common::Split(names.c_str(), ',')) {
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
int tmp = name2idx[name]; int tmp = name2idx[name];
...@@ -80,7 +80,7 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -80,7 +80,7 @@ void DatasetLoader::SetHeader(const char* filename) {
} }
} }
} else { } else {
for (auto token : Common::Split(io_config_.ignore_column.c_str(), ',')) { for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
int tmp = 0; int tmp = 0;
if (!Common::AtoiAndCheck(token.c_str(), &tmp)) { if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Log::Fatal("ignore_column is not a number,\n" Log::Fatal("ignore_column is not a number,\n"
...@@ -92,9 +92,9 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -92,9 +92,9 @@ void DatasetLoader::SetHeader(const char* filename) {
} }
} }
// load weight idx // load weight idx
if (io_config_.weight_column.size() > 0) { if (config_.weight_column.size() > 0) {
if (Common::StartsWith(io_config_.weight_column, name_prefix)) { if (Common::StartsWith(config_.weight_column, name_prefix)) {
std::string name = io_config_.weight_column.substr(name_prefix.size()); std::string name = config_.weight_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
weight_idx_ = name2idx[name]; weight_idx_ = name2idx[name];
Log::Info("Using column %s as weight", name.c_str()); Log::Info("Using column %s as weight", name.c_str());
...@@ -102,7 +102,7 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -102,7 +102,7 @@ void DatasetLoader::SetHeader(const char* filename) {
Log::Fatal("Could not find weight column %s in data file", name.c_str()); Log::Fatal("Could not find weight column %s in data file", name.c_str());
} }
} else { } else {
if (!Common::AtoiAndCheck(io_config_.weight_column.c_str(), &weight_idx_)) { if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
Log::Fatal("weight_column is not a number,\n" Log::Fatal("weight_column is not a number,\n"
"if you want to use a column name,\n" "if you want to use a column name,\n"
"please add the prefix \"name:\" to the column name"); "please add the prefix \"name:\" to the column name");
...@@ -112,9 +112,9 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -112,9 +112,9 @@ void DatasetLoader::SetHeader(const char* filename) {
ignore_features_.emplace(weight_idx_); ignore_features_.emplace(weight_idx_);
} }
// load group idx // load group idx
if (io_config_.group_column.size() > 0) { if (config_.group_column.size() > 0) {
if (Common::StartsWith(io_config_.group_column, name_prefix)) { if (Common::StartsWith(config_.group_column, name_prefix)) {
std::string name = io_config_.group_column.substr(name_prefix.size()); std::string name = config_.group_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
group_idx_ = name2idx[name]; group_idx_ = name2idx[name];
Log::Info("Using column %s as group/query id", name.c_str()); Log::Info("Using column %s as group/query id", name.c_str());
...@@ -122,7 +122,7 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -122,7 +122,7 @@ void DatasetLoader::SetHeader(const char* filename) {
Log::Fatal("Could not find group/query column %s in data file", name.c_str()); Log::Fatal("Could not find group/query column %s in data file", name.c_str());
} }
} else { } else {
if (!Common::AtoiAndCheck(io_config_.group_column.c_str(), &group_idx_)) { if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
Log::Fatal("group_column is not a number,\n" Log::Fatal("group_column is not a number,\n"
"if you want to use a column name,\n" "if you want to use a column name,\n"
"please add the prefix \"name:\" to the column name"); "please add the prefix \"name:\" to the column name");
...@@ -132,22 +132,22 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -132,22 +132,22 @@ void DatasetLoader::SetHeader(const char* filename) {
ignore_features_.emplace(group_idx_); ignore_features_.emplace(group_idx_);
} }
} }
if (io_config_.categorical_column.size() > 0) { if (config_.categorical_feature.size() > 0) {
if (Common::StartsWith(io_config_.categorical_column, name_prefix)) { if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
std::string names = io_config_.categorical_column.substr(name_prefix.size()); std::string names = config_.categorical_feature.substr(name_prefix.size());
for (auto name : Common::Split(names.c_str(), ',')) { for (auto name : Common::Split(names.c_str(), ',')) {
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
int tmp = name2idx[name]; int tmp = name2idx[name];
categorical_features_.emplace(tmp); categorical_features_.emplace(tmp);
} else { } else {
Log::Fatal("Could not find categorical_column %s in data file", name.c_str()); Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
} }
} }
} else { } else {
for (auto token : Common::Split(io_config_.categorical_column.c_str(), ',')) { for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
int tmp = 0; int tmp = 0;
if (!Common::AtoiAndCheck(token.c_str(), &tmp)) { if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Log::Fatal("categorical_column is not a number,\n" Log::Fatal("categorical_feature is not a number,\n"
"if you want to use a column name,\n" "if you want to use a column name,\n"
"please add the prefix \"name:\" to the column name"); "please add the prefix \"name:\" to the column name");
} }
...@@ -159,7 +159,7 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -159,7 +159,7 @@ void DatasetLoader::SetHeader(const char* filename) {
Dataset* DatasetLoader::LoadFromFile(const char* filename, const char* initscore_file, int rank, int num_machines) { Dataset* DatasetLoader::LoadFromFile(const char* filename, const char* initscore_file, int rank, int num_machines) {
// don't support query id in data file when training in parallel // don't support query id in data file when training in parallel
if (num_machines > 1 && !io_config_.is_pre_partition) { if (num_machines > 1 && !config_.pre_partition) {
if (group_idx_ > 0) { if (group_idx_ > 0) {
Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n" Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
"Please use an additional query file or pre-partition the data"); "Please use an additional query file or pre-partition the data");
...@@ -170,14 +170,14 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, const char* initscore ...@@ -170,14 +170,14 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, const char* initscore
std::vector<data_size_t> used_data_indices; std::vector<data_size_t> used_data_indices;
auto bin_filename = CheckCanLoadFromBin(filename); auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) { if (bin_filename.size() == 0) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_)); auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
if (parser == nullptr) { if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename); Log::Fatal("Could not recognize data format of %s", filename);
} }
dataset->data_filename_ = filename; dataset->data_filename_ = filename;
dataset->label_idx_ = label_idx_; dataset->label_idx_ = label_idx_;
dataset->metadata_.Init(filename, initscore_file); dataset->metadata_.Init(filename, initscore_file);
if (!io_config_.use_two_round_loading) { if (!config_.two_round) {
// read data to memory // read data to memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices); auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
dataset->num_data_ = static_cast<data_size_t>(text_data.size()); dataset->num_data_ = static_cast<data_size_t>(text_data.size());
...@@ -225,14 +225,14 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, ...@@ -225,14 +225,14 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
auto dataset = std::unique_ptr<Dataset>(new Dataset()); auto dataset = std::unique_ptr<Dataset>(new Dataset());
auto bin_filename = CheckCanLoadFromBin(filename); auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) { if (bin_filename.size() == 0) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_)); auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
if (parser == nullptr) { if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename); Log::Fatal("Could not recognize data format of %s", filename);
} }
dataset->data_filename_ = filename; dataset->data_filename_ = filename;
dataset->label_idx_ = label_idx_; dataset->label_idx_ = label_idx_;
dataset->metadata_.Init(filename, initscore_file); dataset->metadata_.Init(filename, initscore_file);
if (!io_config_.use_two_round_loading) { if (!config_.two_round) {
// read data in memory // read data in memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices); auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
dataset->num_data_ = static_cast<data_size_t>(text_data.size()); dataset->num_data_ = static_cast<data_size_t>(text_data.size());
...@@ -243,7 +243,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, ...@@ -243,7 +243,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
ExtractFeaturesFromMemory(text_data, parser.get(), dataset.get()); ExtractFeaturesFromMemory(text_data, parser.get(), dataset.get());
text_data.clear(); text_data.clear();
} else { } else {
TextReader<data_size_t> text_reader(filename, io_config_.has_header); TextReader<data_size_t> text_reader(filename, config_.header);
// Get number of lines of data file // Get number of lines of data file
dataset->num_data_ = static_cast<data_size_t>(text_reader.CountLine()); dataset->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
num_global_data = dataset->num_data_; num_global_data = dataset->num_data_;
...@@ -421,7 +421,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -421,7 +421,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
*num_global_data = dataset->num_data_; *num_global_data = dataset->num_data_;
used_data_indices->clear(); used_data_indices->clear();
// sample local used data if need to partition // sample local used data if need to partition
if (num_machines > 1 && !io_config_.is_pre_partition) { if (num_machines > 1 && !config_.pre_partition) {
const data_size_t* query_boundaries = dataset->metadata_.query_boundaries(); const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
if (query_boundaries == nullptr) { if (query_boundaries == nullptr) {
// if not contain query file, minimal sample unit is one record // if not contain query file, minimal sample unit is one record
...@@ -500,7 +500,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -500,7 +500,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
} }
const data_size_t filter_cnt = static_cast<data_size_t>( const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(io_config_.min_data_in_leaf * total_sample_size) / num_data); static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
if (Network::num_machines() == 1) { if (Network::num_machines() == 1) {
// if only one machine, find bin locally // if only one machine, find bin locally
OMP_INIT_EX(); OMP_INIT_EX();
...@@ -517,7 +517,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -517,7 +517,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size, bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type, io_config_.use_missing, io_config_.zero_as_missing); config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -554,7 +554,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -554,7 +554,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], total_sample_size, bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], total_sample_size,
io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type, io_config_.use_missing, io_config_.zero_as_missing); config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -605,7 +605,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -605,7 +605,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
} }
} }
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data)); auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
dataset->Construct(bin_mappers, sample_indices, num_per_col, total_sample_size, io_config_); dataset->Construct(bin_mappers, sample_indices, num_per_col, total_sample_size, config_);
dataset->set_feature_names(feature_names_); dataset->set_feature_names(feature_names_);
return dataset.release(); return dataset.release();
} }
...@@ -650,9 +650,9 @@ void DatasetLoader::CheckDataset(const Dataset* dataset) { ...@@ -650,9 +650,9 @@ void DatasetLoader::CheckDataset(const Dataset* dataset) {
std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata, std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
int rank, int num_machines, int* num_global_data, int rank, int num_machines, int* num_global_data,
std::vector<data_size_t>* used_data_indices) { std::vector<data_size_t>* used_data_indices) {
TextReader<data_size_t> text_reader(filename, io_config_.has_header); TextReader<data_size_t> text_reader(filename, config_.header);
used_data_indices->clear(); used_data_indices->clear();
if (num_machines == 1 || io_config_.is_pre_partition) { if (num_machines == 1 || config_.pre_partition) {
// read all lines // read all lines
*num_global_data = text_reader.ReadAllLines(); *num_global_data = text_reader.ReadAllLines();
} else { // need partition data } else { // need partition data
...@@ -696,7 +696,7 @@ std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filenam ...@@ -696,7 +696,7 @@ std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filenam
} }
std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) { std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
int sample_cnt = io_config_.bin_construct_sample_cnt; int sample_cnt = config_.bin_construct_sample_cnt;
if (static_cast<size_t>(sample_cnt) > data.size()) { if (static_cast<size_t>(sample_cnt) > data.size()) {
sample_cnt = static_cast<int>(data.size()); sample_cnt = static_cast<int>(data.size());
} }
...@@ -710,10 +710,10 @@ std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vect ...@@ -710,10 +710,10 @@ std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vect
} }
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices) { std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices) {
const data_size_t sample_cnt = static_cast<data_size_t>(io_config_.bin_construct_sample_cnt); const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
TextReader<data_size_t> text_reader(filename, io_config_.has_header); TextReader<data_size_t> text_reader(filename, config_.header);
std::vector<std::string> out_data; std::vector<std::string> out_data;
if (num_machines == 1 || io_config_.is_pre_partition) { if (num_machines == 1 || config_.pre_partition) {
*num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(random_, sample_cnt, &out_data)); *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(random_, sample_cnt, &out_data));
} else { // need partition data } else { // need partition data
// get query data // get query data
...@@ -804,7 +804,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -804,7 +804,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
dataset->set_feature_names(feature_names_); dataset->set_feature_names(feature_names_);
std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_); std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
const data_size_t filter_cnt = static_cast<data_size_t>( const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(io_config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_); static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
// start find bins // start find bins
if (num_machines == 1) { if (num_machines == 1) {
...@@ -823,7 +823,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -823,7 +823,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()), bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type, io_config_.use_missing, io_config_.zero_as_missing); sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -860,7 +860,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -860,7 +860,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), static_cast<int>(sample_values[start[rank] + i].size()), bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type, io_config_.use_missing, io_config_.zero_as_missing); sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -912,7 +912,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -912,7 +912,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
} }
sample_values.clear(); sample_values.clear();
dataset->Construct(bin_mappers, Common::Vector2Ptr<int>(sample_indices).data(), dataset->Construct(bin_mappers, Common::Vector2Ptr<int>(sample_indices).data(),
Common::VectorSize<int>(sample_indices).data(), sample_data.size(), io_config_); Common::VectorSize<int>(sample_indices).data(), sample_data.size(), config_);
} }
/*! \brief Extract local features from memory */ /*! \brief Extract local features from memory */
...@@ -1056,7 +1056,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* ...@@ -1056,7 +1056,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
} }
OMP_THROW_EX(); OMP_THROW_EX();
}; };
TextReader<data_size_t> text_reader(filename, io_config_.has_header); TextReader<data_size_t> text_reader(filename, config_.header);
if (!used_data_indices.empty()) { if (!used_data_indices.empty()) {
// only need part of data // only need part of data
text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun); text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
......
...@@ -84,7 +84,7 @@ void getline(std::stringstream& ss, std::string& line, const VirtualFileReader* ...@@ -84,7 +84,7 @@ void getline(std::stringstream& ss, std::string& line, const VirtualFileReader*
} }
} }
Parser* Parser::CreateParser(const char* filename, bool has_header, int num_features, int label_idx) { Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) {
auto reader = VirtualFileReader::Make(filename); auto reader = VirtualFileReader::Make(filename);
if (!reader->Init()) { if (!reader->Init()) {
Log::Fatal("Data file %s doesn't exist", filename); Log::Fatal("Data file %s doesn't exist", filename);
...@@ -98,7 +98,7 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat ...@@ -98,7 +98,7 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
} }
std::stringstream tmp_file(std::string(buffer.data(), read_len)); std::stringstream tmp_file(std::string(buffer.data(), read_len));
if (has_header) { if (header) {
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
getline(tmp_file, line1, reader.get(), buffer, buffer_size); getline(tmp_file, line1, reader.get(), buffer, buffer_size);
} }
......
...@@ -352,10 +352,10 @@ std::string Tree::CategoricalDecisionIfElse(int node) const { ...@@ -352,10 +352,10 @@ std::string Tree::CategoricalDecisionIfElse(int node) const {
return str_buf.str(); return str_buf.str();
} }
std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const { std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << "double PredictTree" << index; str_buf << "double PredictTree" << index;
if (is_predict_leaf_index) { if (predict_leaf_index) {
str_buf << "Leaf"; str_buf << "Leaf";
} }
str_buf << "(const double* arr) { "; str_buf << "(const double* arr) { ";
...@@ -375,13 +375,13 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const { ...@@ -375,13 +375,13 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const {
if (num_cat_ > 0) { if (num_cat_ > 0) {
str_buf << "int int_fval = 0; "; str_buf << "int int_fval = 0; ";
} }
str_buf << NodeToIfElse(0, is_predict_leaf_index); str_buf << NodeToIfElse(0, predict_leaf_index);
} }
str_buf << " }" << '\n'; str_buf << " }" << '\n';
//Predict func by Map to ifelse //Predict func by Map to ifelse
str_buf << "double PredictTree" << index; str_buf << "double PredictTree" << index;
if (is_predict_leaf_index) { if (predict_leaf_index) {
str_buf << "LeafByMap"; str_buf << "LeafByMap";
} else { } else {
str_buf << "ByMap"; str_buf << "ByMap";
...@@ -403,14 +403,14 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const { ...@@ -403,14 +403,14 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const {
if (num_cat_ > 0) { if (num_cat_ > 0) {
str_buf << "int int_fval = 0; "; str_buf << "int int_fval = 0; ";
} }
str_buf << NodeToIfElseByMap(0, is_predict_leaf_index); str_buf << NodeToIfElseByMap(0, predict_leaf_index);
} }
str_buf << " }" << '\n'; str_buf << " }" << '\n';
return str_buf.str(); return str_buf.str();
} }
std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) const { std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2); str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
if (index >= 0) { if (index >= 0) {
...@@ -422,15 +422,15 @@ std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) const { ...@@ -422,15 +422,15 @@ std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) const {
str_buf << CategoricalDecisionIfElse(index); str_buf << CategoricalDecisionIfElse(index);
} }
// left subtree // left subtree
str_buf << NodeToIfElse(left_child_[index], is_predict_leaf_index); str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
str_buf << " } else { "; str_buf << " } else { ";
// right subtree // right subtree
str_buf << NodeToIfElse(right_child_[index], is_predict_leaf_index); str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
str_buf << " }"; str_buf << " }";
} else { } else {
// leaf // leaf
str_buf << "return "; str_buf << "return ";
if (is_predict_leaf_index) { if (predict_leaf_index) {
str_buf << ~index; str_buf << ~index;
} else { } else {
str_buf << leaf_value_[~index]; str_buf << leaf_value_[~index];
...@@ -441,7 +441,7 @@ std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) const { ...@@ -441,7 +441,7 @@ std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) const {
return str_buf.str(); return str_buf.str();
} }
std::string Tree::NodeToIfElseByMap(int index, bool is_predict_leaf_index) const { std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2); str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
if (index >= 0) { if (index >= 0) {
...@@ -453,15 +453,15 @@ std::string Tree::NodeToIfElseByMap(int index, bool is_predict_leaf_index) const ...@@ -453,15 +453,15 @@ std::string Tree::NodeToIfElseByMap(int index, bool is_predict_leaf_index) const
str_buf << CategoricalDecisionIfElse(index); str_buf << CategoricalDecisionIfElse(index);
} }
// left subtree // left subtree
str_buf << NodeToIfElseByMap(left_child_[index], is_predict_leaf_index); str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
str_buf << " } else { "; str_buf << " } else { ";
// right subtree // right subtree
str_buf << NodeToIfElseByMap(right_child_[index], is_predict_leaf_index); str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
str_buf << " }"; str_buf << " }";
} else { } else {
// leaf // leaf
str_buf << "return "; str_buf << "return ";
if (is_predict_leaf_index) { if (predict_leaf_index) {
str_buf << ~index; str_buf << ~index;
} else { } else {
str_buf << leaf_value_[~index]; str_buf << leaf_value_[~index];
......
...@@ -19,7 +19,7 @@ namespace LightGBM { ...@@ -19,7 +19,7 @@ namespace LightGBM {
template<typename PointWiseLossCalculator> template<typename PointWiseLossCalculator>
class BinaryMetric: public Metric { class BinaryMetric: public Metric {
public: public:
explicit BinaryMetric(const MetricConfig&) { explicit BinaryMetric(const Config&) {
} }
...@@ -112,7 +112,7 @@ private: ...@@ -112,7 +112,7 @@ private:
*/ */
class BinaryLoglossMetric: public BinaryMetric<BinaryLoglossMetric> { class BinaryLoglossMetric: public BinaryMetric<BinaryLoglossMetric> {
public: public:
explicit BinaryLoglossMetric(const MetricConfig& config) :BinaryMetric<BinaryLoglossMetric>(config) {} explicit BinaryLoglossMetric(const Config& config) :BinaryMetric<BinaryLoglossMetric>(config) {}
inline static double LossOnPoint(label_t label, double prob) { inline static double LossOnPoint(label_t label, double prob) {
if (label <= 0) { if (label <= 0) {
...@@ -136,7 +136,7 @@ public: ...@@ -136,7 +136,7 @@ public:
*/ */
class BinaryErrorMetric: public BinaryMetric<BinaryErrorMetric> { class BinaryErrorMetric: public BinaryMetric<BinaryErrorMetric> {
public: public:
explicit BinaryErrorMetric(const MetricConfig& config) :BinaryMetric<BinaryErrorMetric>(config) {} explicit BinaryErrorMetric(const Config& config) :BinaryMetric<BinaryErrorMetric>(config) {}
inline static double LossOnPoint(label_t label, double prob) { inline static double LossOnPoint(label_t label, double prob) {
if (prob <= 0.5f) { if (prob <= 0.5f) {
...@@ -156,7 +156,7 @@ public: ...@@ -156,7 +156,7 @@ public:
*/ */
class AUCMetric: public Metric { class AUCMetric: public Metric {
public: public:
explicit AUCMetric(const MetricConfig&) { explicit AUCMetric(const Config&) {
} }
......
...@@ -14,7 +14,30 @@ std::vector<double> DCGCalculator::label_gain_; ...@@ -14,7 +14,30 @@ std::vector<double> DCGCalculator::label_gain_;
std::vector<double> DCGCalculator::discount_; std::vector<double> DCGCalculator::discount_;
const data_size_t DCGCalculator::kMaxPosition = 10000; const data_size_t DCGCalculator::kMaxPosition = 10000;
void DCGCalculator::Init(std::vector<double> input_label_gain) {
void DCGCalculator::DefaultEvalAt(std::vector<int>* eval_at) {
  // No positions supplied: default to evaluating NDCG/MAP at 1 through 5.
  if (eval_at->empty()) {
    for (int pos = 1; pos <= 5; ++pos) {
      eval_at->push_back(pos);
    }
    return;
  }
  // Positions were supplied by the user: each one must be strictly positive.
  for (auto pos : *eval_at) {
    CHECK(pos > 0);
  }
}
void DCGCalculator::DefaultLabelGain(std::vector<double>* label_gain) {
  // Respect gains the caller already provided.
  if (!label_gain->empty()) {
    return;
  }
  // Default gain for relevance label i is 2^i - 1. Shifting past 30 bits
  // would overflow a 32-bit int, so only the first 31 labels are filled.
  const int max_label = 31;
  label_gain->reserve(max_label);
  label_gain->push_back(0.0f);
  for (int i = 1; i < max_label; ++i) {
    const double gain = static_cast<double>((1 << i) - 1);
    label_gain->push_back(gain);
  }
}
void DCGCalculator::Init(const std::vector<double>& input_label_gain) {
label_gain_.resize(input_label_gain.size()); label_gain_.resize(input_label_gain.size());
for(size_t i = 0;i < input_label_gain.size();++i){ for(size_t i = 0;i < input_label_gain.size();++i){
label_gain_[i] = static_cast<double>(input_label_gain[i]); label_gain_[i] = static_cast<double>(input_label_gain[i]);
......
...@@ -15,11 +15,10 @@ namespace LightGBM { ...@@ -15,11 +15,10 @@ namespace LightGBM {
class MapMetric:public Metric { class MapMetric:public Metric {
public: public:
explicit MapMetric(const MetricConfig& config) { explicit MapMetric(const Config& config) {
// get eval position // get eval position
for (auto k : config.eval_at) { eval_at_ = config.eval_at;
eval_at_.push_back(static_cast<data_size_t>(k)); DCGCalculator::DefaultEvalAt(&eval_at_);
}
// get number of threads // get number of threads
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace LightGBM { namespace LightGBM {
Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) { Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
if (type == std::string("regression") || type == std::string("regression_l2") || type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) { if (type == std::string("regression") || type == std::string("regression_l2") || type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
return new L2Metric(config); return new L2Metric(config);
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) { } else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
......
...@@ -15,7 +15,7 @@ namespace LightGBM { ...@@ -15,7 +15,7 @@ namespace LightGBM {
template<typename PointWiseLossCalculator> template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric { class MulticlassMetric: public Metric {
public: public:
explicit MulticlassMetric(const MetricConfig& config) { explicit MulticlassMetric(const Config& config) {
num_class_ = config.num_class; num_class_ = config.num_class;
} }
...@@ -131,7 +131,7 @@ private: ...@@ -131,7 +131,7 @@ private:
/*! \brief L2 loss for multiclass task */ /*! \brief L2 loss for multiclass task */
class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> { class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
public: public:
explicit MultiErrorMetric(const MetricConfig& config) :MulticlassMetric<MultiErrorMetric>(config) {} explicit MultiErrorMetric(const Config& config) :MulticlassMetric<MultiErrorMetric>(config) {}
inline static double LossOnPoint(label_t label, std::vector<double>& score) { inline static double LossOnPoint(label_t label, std::vector<double>& score) {
size_t k = static_cast<size_t>(label); size_t k = static_cast<size_t>(label);
...@@ -151,7 +151,7 @@ public: ...@@ -151,7 +151,7 @@ public:
/*! \brief Logloss for multiclass task */ /*! \brief Logloss for multiclass task */
class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> { class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
public: public:
explicit MultiSoftmaxLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {} explicit MultiSoftmaxLoglossMetric(const Config& config) :MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {}
inline static double LossOnPoint(label_t label, std::vector<double>& score) { inline static double LossOnPoint(label_t label, std::vector<double>& score) {
size_t k = static_cast<size_t>(label); size_t k = static_cast<size_t>(label);
......
...@@ -15,14 +15,14 @@ namespace LightGBM { ...@@ -15,14 +15,14 @@ namespace LightGBM {
class NDCGMetric:public Metric { class NDCGMetric:public Metric {
public: public:
explicit NDCGMetric(const MetricConfig& config) { explicit NDCGMetric(const Config& config) {
// get eval position // get eval position
for (auto k : config.eval_at) { eval_at_ = config.eval_at;
eval_at_.push_back(static_cast<data_size_t>(k)); auto label_gain = config.label_gain;
} DCGCalculator::DefaultEvalAt(&eval_at_);
eval_at_.shrink_to_fit(); DCGCalculator::DefaultLabelGain(&label_gain);
// initialize DCG calculator // initialize DCG calculator
DCGCalculator::Init(config.label_gain); DCGCalculator::Init(label_gain);
// get number of threads // get number of threads
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
......
...@@ -15,7 +15,7 @@ namespace LightGBM { ...@@ -15,7 +15,7 @@ namespace LightGBM {
template<typename PointWiseLossCalculator> template<typename PointWiseLossCalculator>
class RegressionMetric: public Metric { class RegressionMetric: public Metric {
public: public:
explicit RegressionMetric(const MetricConfig& config) :config_(config) { explicit RegressionMetric(const Config& config) :config_(config) {
} }
virtual ~RegressionMetric() { virtual ~RegressionMetric() {
...@@ -100,16 +100,16 @@ private: ...@@ -100,16 +100,16 @@ private:
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of this test set */ /*! \brief Name of this test set */
MetricConfig config_; Config config_;
std::vector<std::string> name_; std::vector<std::string> name_;
}; };
/*! \brief RMSE loss for regression task */ /*! \brief RMSE loss for regression task */
class RMSEMetric: public RegressionMetric<RMSEMetric> { class RMSEMetric: public RegressionMetric<RMSEMetric> {
public: public:
explicit RMSEMetric(const MetricConfig& config) :RegressionMetric<RMSEMetric>(config) {} explicit RMSEMetric(const Config& config) :RegressionMetric<RMSEMetric>(config) {}
inline static double LossOnPoint(label_t label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const Config&) {
return (score - label)*(score - label); return (score - label)*(score - label);
} }
...@@ -126,9 +126,9 @@ public: ...@@ -126,9 +126,9 @@ public:
/*! \brief L2 loss for regression task */ /*! \brief L2 loss for regression task */
class L2Metric: public RegressionMetric<L2Metric> { class L2Metric: public RegressionMetric<L2Metric> {
public: public:
explicit L2Metric(const MetricConfig& config) :RegressionMetric<L2Metric>(config) {} explicit L2Metric(const Config& config) :RegressionMetric<L2Metric>(config) {}
inline static double LossOnPoint(label_t label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const Config&) {
return (score - label)*(score - label); return (score - label)*(score - label);
} }
...@@ -140,10 +140,10 @@ public: ...@@ -140,10 +140,10 @@ public:
/*! \brief L2 loss for regression task */ /*! \brief L2 loss for regression task */
class QuantileMetric : public RegressionMetric<QuantileMetric> { class QuantileMetric : public RegressionMetric<QuantileMetric> {
public: public:
explicit QuantileMetric(const MetricConfig& config) :RegressionMetric<QuantileMetric>(config) { explicit QuantileMetric(const Config& config) :RegressionMetric<QuantileMetric>(config) {
} }
inline static double LossOnPoint(label_t label, double score, const MetricConfig& config) { inline static double LossOnPoint(label_t label, double score, const Config& config) {
double delta = label - score; double delta = label - score;
if (delta < 0) { if (delta < 0) {
return (config.alpha - 1.0f) * delta; return (config.alpha - 1.0f) * delta;
...@@ -161,9 +161,9 @@ public: ...@@ -161,9 +161,9 @@ public:
/*! \brief L1 loss for regression task */ /*! \brief L1 loss for regression task */
class L1Metric: public RegressionMetric<L1Metric> { class L1Metric: public RegressionMetric<L1Metric> {
public: public:
explicit L1Metric(const MetricConfig& config) :RegressionMetric<L1Metric>(config) {} explicit L1Metric(const Config& config) :RegressionMetric<L1Metric>(config) {}
inline static double LossOnPoint(label_t label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const Config&) {
return std::fabs(score - label); return std::fabs(score - label);
} }
inline static const char* Name() { inline static const char* Name() {
...@@ -174,10 +174,10 @@ public: ...@@ -174,10 +174,10 @@ public:
/*! \brief Huber loss for regression task */ /*! \brief Huber loss for regression task */
class HuberLossMetric: public RegressionMetric<HuberLossMetric> { class HuberLossMetric: public RegressionMetric<HuberLossMetric> {
public: public:
explicit HuberLossMetric(const MetricConfig& config) :RegressionMetric<HuberLossMetric>(config) { explicit HuberLossMetric(const Config& config) :RegressionMetric<HuberLossMetric>(config) {
} }
inline static double LossOnPoint(label_t label, double score, const MetricConfig& config) { inline static double LossOnPoint(label_t label, double score, const Config& config) {
const double diff = score - label; const double diff = score - label;
if (std::abs(diff) <= config.alpha) { if (std::abs(diff) <= config.alpha) {
return 0.5f * diff * diff; return 0.5f * diff * diff;
...@@ -195,10 +195,10 @@ public: ...@@ -195,10 +195,10 @@ public:
// http://research.microsoft.com/en-us/um/people/zhang/INRIA/Publis/Tutorial-Estim/node24.html // http://research.microsoft.com/en-us/um/people/zhang/INRIA/Publis/Tutorial-Estim/node24.html
class FairLossMetric: public RegressionMetric<FairLossMetric> { class FairLossMetric: public RegressionMetric<FairLossMetric> {
public: public:
explicit FairLossMetric(const MetricConfig& config) :RegressionMetric<FairLossMetric>(config) { explicit FairLossMetric(const Config& config) :RegressionMetric<FairLossMetric>(config) {
} }
inline static double LossOnPoint(label_t label, double score, const MetricConfig& config) { inline static double LossOnPoint(label_t label, double score, const Config& config) {
const double x = std::fabs(score - label); const double x = std::fabs(score - label);
const double c = config.fair_c; const double c = config.fair_c;
return c * x - c * c * std::log(1.0f + x / c); return c * x - c * c * std::log(1.0f + x / c);
...@@ -212,10 +212,10 @@ public: ...@@ -212,10 +212,10 @@ public:
/*! \brief Poisson regression loss for regression task */ /*! \brief Poisson regression loss for regression task */
class PoissonMetric: public RegressionMetric<PoissonMetric> { class PoissonMetric: public RegressionMetric<PoissonMetric> {
public: public:
explicit PoissonMetric(const MetricConfig& config) :RegressionMetric<PoissonMetric>(config) { explicit PoissonMetric(const Config& config) :RegressionMetric<PoissonMetric>(config) {
} }
inline static double LossOnPoint(label_t label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const Config&) {
const double eps = 1e-10f; const double eps = 1e-10f;
if (score < eps) { if (score < eps) {
score = eps; score = eps;
...@@ -231,10 +231,10 @@ public: ...@@ -231,10 +231,10 @@ public:
/*! \brief Mape regression loss for regression task */ /*! \brief Mape regression loss for regression task */
class MAPEMetric : public RegressionMetric<MAPEMetric> { class MAPEMetric : public RegressionMetric<MAPEMetric> {
public: public:
explicit MAPEMetric(const MetricConfig& config) :RegressionMetric<MAPEMetric>(config) { explicit MAPEMetric(const Config& config) :RegressionMetric<MAPEMetric>(config) {
} }
inline static double LossOnPoint(label_t label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const Config&) {
return std::fabs((label - score)) / std::max(1.0f, std::fabs(label)); return std::fabs((label - score)) / std::max(1.0f, std::fabs(label));
} }
inline static const char* Name() { inline static const char* Name() {
...@@ -244,10 +244,10 @@ public: ...@@ -244,10 +244,10 @@ public:
class GammaMetric : public RegressionMetric<GammaMetric> { class GammaMetric : public RegressionMetric<GammaMetric> {
public: public:
explicit GammaMetric(const MetricConfig& config) :RegressionMetric<GammaMetric>(config) { explicit GammaMetric(const Config& config) :RegressionMetric<GammaMetric>(config) {
} }
inline static double LossOnPoint(label_t label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const Config&) {
const double psi = 1.0; const double psi = 1.0;
const double theta = -1.0 / score; const double theta = -1.0 / score;
const double a = psi; const double a = psi;
...@@ -263,10 +263,10 @@ public: ...@@ -263,10 +263,10 @@ public:
class GammaDevianceMetric : public RegressionMetric<GammaDevianceMetric> { class GammaDevianceMetric : public RegressionMetric<GammaDevianceMetric> {
public: public:
explicit GammaDevianceMetric(const MetricConfig& config) :RegressionMetric<GammaDevianceMetric>(config) { explicit GammaDevianceMetric(const Config& config) :RegressionMetric<GammaDevianceMetric>(config) {
} }
inline static double LossOnPoint(label_t label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const Config&) {
const double epsilon = 1.0e-9; const double epsilon = 1.0e-9;
const double tmp = label / (score + epsilon); const double tmp = label / (score + epsilon);
return tmp - std::log(tmp) - 1; return tmp - std::log(tmp) - 1;
...@@ -281,10 +281,10 @@ public: ...@@ -281,10 +281,10 @@ public:
class TweedieMetric : public RegressionMetric<TweedieMetric> { class TweedieMetric : public RegressionMetric<TweedieMetric> {
public: public:
explicit TweedieMetric(const MetricConfig& config) :RegressionMetric<TweedieMetric>(config) { explicit TweedieMetric(const Config& config) :RegressionMetric<TweedieMetric>(config) {
} }
inline static double LossOnPoint(label_t label, double score, const MetricConfig& config) { inline static double LossOnPoint(label_t label, double score, const Config& config) {
const double rho = config.tweedie_variance_power; const double rho = config.tweedie_variance_power;
const double a = label * std::exp((1 - rho) * std::log(score)) / (1 - rho); const double a = label * std::exp((1 - rho) * std::log(score)) / (1 - rho);
const double b = std::exp((2 - rho) * std::log(score)) / (2 - rho); const double b = std::exp((2 - rho) * std::log(score)) / (2 - rho);
......
...@@ -67,7 +67,7 @@ namespace LightGBM { ...@@ -67,7 +67,7 @@ namespace LightGBM {
// //
class CrossEntropyMetric : public Metric { class CrossEntropyMetric : public Metric {
public: public:
explicit CrossEntropyMetric(const MetricConfig&) {} explicit CrossEntropyMetric(const Config&) {}
virtual ~CrossEntropyMetric() {} virtual ~CrossEntropyMetric() {}
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
...@@ -162,7 +162,7 @@ private: ...@@ -162,7 +162,7 @@ private:
// //
class CrossEntropyLambdaMetric : public Metric { class CrossEntropyLambdaMetric : public Metric {
public: public:
explicit CrossEntropyLambdaMetric(const MetricConfig&) {} explicit CrossEntropyLambdaMetric(const Config&) {}
virtual ~CrossEntropyLambdaMetric() {} virtual ~CrossEntropyLambdaMetric() {}
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
...@@ -246,7 +246,7 @@ private: ...@@ -246,7 +246,7 @@ private:
// //
class KullbackLeiblerDivergence : public Metric { class KullbackLeiblerDivergence : public Metric {
public: public:
explicit KullbackLeiblerDivergence(const MetricConfig&) {} explicit KullbackLeiblerDivergence(const Config&) {}
virtual ~KullbackLeiblerDivergence() {} virtual ~KullbackLeiblerDivergence() {}
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
......
...@@ -40,7 +40,7 @@ public: ...@@ -40,7 +40,7 @@ public:
* \brief Constructor * \brief Constructor
* \param config Config of network settings * \param config Config of network settings
*/ */
explicit Linkers(NetworkConfig config); explicit Linkers(Config config);
/*! /*!
* \brief Destructor * \brief Destructor
*/ */
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace LightGBM { namespace LightGBM {
Linkers::Linkers(NetworkConfig) { Linkers::Linkers(Config) {
is_init_ = false; is_init_ = false;
int argc = 0; int argc = 0;
char**argv = nullptr; char**argv = nullptr;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace LightGBM { namespace LightGBM {
Linkers::Linkers(NetworkConfig config) { Linkers::Linkers(Config config) {
is_init_ = false; is_init_ = false;
// start up socket // start up socket
TcpSocket::Startup(); TcpSocket::Startup();
......
...@@ -23,7 +23,7 @@ THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = nullptr; ...@@ -23,7 +23,7 @@ THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = nullptr;
THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = nullptr; THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = nullptr;
void Network::Init(NetworkConfig config) { void Network::Init(Config config) {
if (config.num_machines > 1) { if (config.num_machines > 1) {
linkers_.reset(new Linkers(config)); linkers_.reset(new Linkers(config));
rank_ = linkers_->rank(); rank_ = linkers_->rank();
......
...@@ -12,7 +12,7 @@ namespace LightGBM { ...@@ -12,7 +12,7 @@ namespace LightGBM {
*/ */
class BinaryLogloss: public ObjectiveFunction { class BinaryLogloss: public ObjectiveFunction {
public: public:
explicit BinaryLogloss(const ObjectiveConfig& config, std::function<bool(label_t)> is_pos = nullptr) { explicit BinaryLogloss(const Config& config, std::function<bool(label_t)> is_pos = nullptr) {
sigmoid_ = static_cast<double>(config.sigmoid); sigmoid_ = static_cast<double>(config.sigmoid);
if (sigmoid_ <= 0.0) { if (sigmoid_ <= 0.0) {
Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_); Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment