"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "e17a28ec0c66336222bfbc335d9bce6faaa53e0e"
Commit dce329e5 authored by Hui Xue's avatar Hui Xue
Browse files

Merge remote-tracking branch 'upstream/master'

# Conflicts:
#	src/io/dataset.cpp
#	src/io/ordered_sparse_bin.hpp
#	src/treelearner/leaf_splits.hpp
#	src/treelearner/serial_tree_learner.cpp
parents 0b9fe27a a6a75fe9
...@@ -24,7 +24,7 @@ public: ...@@ -24,7 +24,7 @@ public:
*/ */
~GBDT(); ~GBDT();
/*! /*!
* \brief Initial logic * \brief Initialization logic
* \param config Config for boosting * \param config Config for boosting
* \param train_data Training data * \param train_data Training data
* \param object_function Training objective function * \param object_function Training objective function
...@@ -36,9 +36,9 @@ public: ...@@ -36,9 +36,9 @@ public:
const char* output_model_filename) const char* output_model_filename)
override; override;
/*! /*!
* \brief Add a validation data * \brief Adding a validation dataset
* \param valid_data Validation data * \param valid_data Validation dataset
* \param valid_metrics Metrics for validation data * \param valid_metrics Metrics for validation dataset
*/ */
void AddDataset(const Dataset* valid_data, void AddDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) override; const std::vector<const Metric*>& valid_metrics) override;
...@@ -47,18 +47,26 @@ public: ...@@ -47,18 +47,26 @@ public:
*/ */
void Train() override; void Train() override;
/*! /*!
* \brief Prediction for one record, not use sigmoid * \brief Prediction for one record without sigmoid transformation
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double PredictRaw(const double * feature_values) const override; double PredictRaw(const double * feature_values) const override;
/*! /*!
* \brief Prediction for one record, will use sigmoid transform if needed * \brief Prediction for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double Predict(const double * feature_values) const override; double Predict(const double * feature_values) const override;
/*!
* \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record
* \return Predicted leaf index for this record
*/
std::vector<int> PredictLeafIndex(const double* value) const override;
/*! /*!
* \brief Serialize models by string * \brief Serialize models by string
* \return String output of trained model * \return String output of trained model
...@@ -87,8 +95,8 @@ private: ...@@ -87,8 +95,8 @@ private:
*/ */
void Bagging(int iter); void Bagging(int iter);
/*! /*!
* \brief update score for out-of-bag data. * \brief updating score for out-of-bag data.
* It is necessary for this update, since we may re-bagging data on training * Data should be update since we may re-bagging data on training
* \param tree Trained tree of this iteration * \param tree Trained tree of this iteration
*/ */
void UpdateScoreOutOfBag(const Tree* tree); void UpdateScoreOutOfBag(const Tree* tree);
...@@ -97,12 +105,12 @@ private: ...@@ -97,12 +105,12 @@ private:
*/ */
void Boosting(); void Boosting();
/*! /*!
* \brief train one tree * \brief training one tree
* \return Trained tree of this iteration * \return Trained tree of this iteration
*/ */
Tree* TrainOneTree(); Tree* TrainOneTree();
/*! /*!
* \brief update score after tree trained * \brief updating score after tree was trained
* \param tree Trained tree of this iteration * \param tree Trained tree of this iteration
*/ */
void UpdateScore(const Tree* tree); void UpdateScore(const Tree* tree);
...@@ -110,8 +118,10 @@ private: ...@@ -110,8 +118,10 @@ private:
* \brief Print metric result of current iteration * \brief Print metric result of current iteration
* \param iter Current iteration * \param iter Current iteration
*/ */
void OutputMetric(int iter); bool OutputMetric(int iter);
int early_stopping_round_;
/*! \brief Pointer to training data */ /*! \brief Pointer to training data */
const Dataset* train_data_; const Dataset* train_data_;
/*! \brief Config of gbdt */ /*! \brief Config of gbdt */
...@@ -128,6 +138,9 @@ private: ...@@ -128,6 +138,9 @@ private:
std::vector<ScoreUpdater*> valid_score_updater_; std::vector<ScoreUpdater*> valid_score_updater_;
/*! \brief Metric for validation data */ /*! \brief Metric for validation data */
std::vector<std::vector<const Metric*>> valid_metrics_; std::vector<std::vector<const Metric*>> valid_metrics_;
/*! \brief Best score(s) for early stopping */
std::vector<std::vector<int>> best_iter_;
std::vector<std::vector<score_t>> best_score_;
/*! \brief Trained models(trees) */ /*! \brief Trained models(trees) */
std::vector<Tree*> models_; std::vector<Tree*> models_;
/*! \brief Max feature index of training data*/ /*! \brief Max feature index of training data*/
......
...@@ -37,25 +37,25 @@ public: ...@@ -37,25 +37,25 @@ public:
delete[] score_; delete[] score_;
} }
/*! /*!
* \brief Use tree model to get prediction, then add to score for all data * \brief Using tree model to get prediction number, then adding to scores for all data
* Note: this function generally will be used for validation data. * Note: this function generally will be used on validation data too.
* \param tree Trained tree model * \param tree Trained tree model
*/ */
inline void AddScore(const Tree* tree) { inline void AddScore(const Tree* tree) {
tree->AddPredictionToScore(data_, num_data_, score_); tree->AddPredictionToScore(data_, num_data_, score_);
} }
/*! /*!
* \brief Add prediction score, only used for training data. * \brief Adding prediction score, only used for training data.
* After trained a tree, the training data is partitioned into tree leaves. * The training data is partitioned into tree leaves after training
* We can get prediction by faster speed based on this. * Based on which we can get prediction quickly.
* \param tree_learner * \param tree_learner
*/ */
inline void AddScore(const TreeLearner* tree_learner) { inline void AddScore(const TreeLearner* tree_learner) {
tree_learner->AddPredictionToScore(score_); tree_learner->AddPredictionToScore(score_);
} }
/*! /*!
* \brief Like AddScore(const Tree* tree), but only for part of data * \brief Using tree model to get prediction number, then adding to scores for parts of data
* Used for prediction of training out-of-bad data * Used for prediction of training out-of-bag data
* \param tree Trained tree model * \param tree Trained tree model
* \param data_indices Indices of data that will be processed * \param data_indices Indices of data that will be processed
* \param data_cnt Number of data that will be processed * \param data_cnt Number of data that will be processed
......
...@@ -182,35 +182,35 @@ template class OrderedSparseBin<uint16_t>; ...@@ -182,35 +182,35 @@ template class OrderedSparseBin<uint16_t>;
template class OrderedSparseBin<uint32_t>; template class OrderedSparseBin<uint32_t>;
Bin* Bin::CreateBin(data_size_t num_data, int num_bin, double sparse_rate, bool is_enable_sparse, bool* is_sparse) { Bin* Bin::CreateBin(data_size_t num_data, int num_bin, double sparse_rate, bool is_enable_sparse, bool* is_sparse, int default_bin) {
// sparse threshold // sparse threshold
const double kSparseThreshold = 0.8; const double kSparseThreshold = 0.8;
if (sparse_rate >= kSparseThreshold && is_enable_sparse) { if (sparse_rate >= kSparseThreshold && is_enable_sparse) {
*is_sparse = true; *is_sparse = true;
return CreateSparseBin(num_data, num_bin); return CreateSparseBin(num_data, num_bin, default_bin);
} else { } else {
*is_sparse = false; *is_sparse = false;
return CreateDenseBin(num_data, num_bin); return CreateDenseBin(num_data, num_bin, default_bin);
} }
} }
Bin* Bin::CreateDenseBin(data_size_t num_data, int num_bin) { Bin* Bin::CreateDenseBin(data_size_t num_data, int num_bin, int default_bin) {
if (num_bin <= 256) { if (num_bin <= 256) {
return new DenseBin<uint8_t>(num_data); return new DenseBin<uint8_t>(num_data, default_bin);
} else if (num_bin <= 65536) { } else if (num_bin <= 65536) {
return new DenseBin<uint16_t>(num_data); return new DenseBin<uint16_t>(num_data, default_bin);
} else { } else {
return new DenseBin<uint32_t>(num_data); return new DenseBin<uint32_t>(num_data, default_bin);
} }
} }
Bin* Bin::CreateSparseBin(data_size_t num_data, int num_bin) { Bin* Bin::CreateSparseBin(data_size_t num_data, int num_bin, int default_bin) {
if (num_bin <= 256) { if (num_bin <= 256) {
return new SparseBin<uint8_t>(num_data); return new SparseBin<uint8_t>(num_data, default_bin);
} else if (num_bin <= 65536) { } else if (num_bin <= 65536) {
return new SparseBin<uint16_t>(num_data); return new SparseBin<uint16_t>(num_data, default_bin);
} else { } else {
return new SparseBin<uint32_t>(num_data); return new SparseBin<uint32_t>(num_data, default_bin);
} }
} }
......
...@@ -14,6 +14,8 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -14,6 +14,8 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
// load main config types // load main config types
GetInt(params, "num_threads", &num_threads); GetInt(params, "num_threads", &num_threads);
GetTaskType(params); GetTaskType(params);
GetBool(params, "predict_leaf_index", &predict_leaf_index);
GetBoostingType(params); GetBoostingType(params);
GetObjectiveType(params); GetObjectiveType(params);
...@@ -34,6 +36,19 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -34,6 +36,19 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
metric_config.Set(params); metric_config.Set(params);
// check for conflicts // check for conflicts
CheckParamConflict(); CheckParamConflict();
if (io_config.verbosity == 1) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
}
else if (io_config.verbosity == 0) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Error);
}
else if (io_config.verbosity >= 2) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
}
else {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
}
} }
void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::string>& params) { void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
...@@ -43,7 +58,7 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s ...@@ -43,7 +58,7 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
if (value == std::string("gbdt") || value == std::string("gbrt")) { if (value == std::string("gbdt") || value == std::string("gbrt")) {
boosting_type = BoostingType::kGBDT; boosting_type = BoostingType::kGBDT;
} else { } else {
Log::Stderr("boosting type %s error", value.c_str()); Log::Fatal("Boosting type %s error", value.c_str());
} }
} }
} }
...@@ -91,34 +106,37 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin ...@@ -91,34 +106,37 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
|| value == std::string("test")) { || value == std::string("test")) {
task_type = TaskType::kPredict; task_type = TaskType::kPredict;
} else { } else {
Log::Stderr("task type error"); Log::Fatal("Task type error");
} }
} }
} }
void OverallConfig::CheckParamConflict() { void OverallConfig::CheckParamConflict() {
GBDTConfig* gbdt_config = dynamic_cast<GBDTConfig*>(boosting_config);
if (network_config.num_machines > 1) { if (network_config.num_machines > 1) {
is_parallel = true; is_parallel = true;
} else { } else {
is_parallel = false; is_parallel = false;
dynamic_cast<GBDTConfig*>(boosting_config)->tree_learner_type = gbdt_config->tree_learner_type = TreeLearnerType::kSerialTreeLearner;
TreeLearnerType::kSerialTreeLearner;
} }
if (dynamic_cast<GBDTConfig*>(boosting_config)->tree_learner_type == if (gbdt_config->tree_learner_type == TreeLearnerType::kSerialTreeLearner) {
TreeLearnerType::kSerialTreeLearner) {
is_parallel = false; is_parallel = false;
network_config.num_machines = 1; network_config.num_machines = 1;
} }
if (dynamic_cast<GBDTConfig*>(boosting_config)->tree_learner_type == if (gbdt_config->tree_learner_type == TreeLearnerType::kSerialTreeLearner ||
TreeLearnerType::kSerialTreeLearner || gbdt_config->tree_learner_type == TreeLearnerType::kFeatureParallelTreelearner) {
dynamic_cast<GBDTConfig*>(boosting_config)->tree_learner_type ==
TreeLearnerType::kFeatureParallelTreelearner) {
is_parallel_find_bin = false; is_parallel_find_bin = false;
} else if (dynamic_cast<GBDTConfig*>(boosting_config)->tree_learner_type == } else if (gbdt_config->tree_learner_type == TreeLearnerType::kDataParallelTreeLearner) {
TreeLearnerType::kDataParallelTreeLearner) {
is_parallel_find_bin = true; is_parallel_find_bin = true;
if (gbdt_config->tree_config.histogram_pool_size >= 0) {
Log::Error("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this for reducing communication cost."
, gbdt_config->tree_config.histogram_pool_size);
// Change pool size to -1(not limit) when using data parallel for reducing communication cost
gbdt_config->tree_config.histogram_pool_size = -1;
}
} }
} }
...@@ -128,8 +146,9 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -128,8 +146,9 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "data_random_seed", &data_random_seed); GetInt(params, "data_random_seed", &data_random_seed);
if (!GetString(params, "data", &data_filename)) { if (!GetString(params, "data", &data_filename)) {
Log::Stderr("No training/prediction data, application quit"); Log::Fatal("No training/prediction data, application quit");
} }
GetInt(params, "verbose", &verbosity);
GetInt(params, "num_model_predict", &num_model_predict); GetInt(params, "num_model_predict", &num_model_predict);
GetBool(params, "is_pre_partition", &is_pre_partition); GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse); GetBool(params, "is_enable_sparse", &is_enable_sparse);
...@@ -140,6 +159,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -140,6 +159,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetString(params, "input_model", &input_model); GetString(params, "input_model", &input_model);
GetString(params, "output_result", &output_result); GetString(params, "output_result", &output_result);
GetString(params, "input_init_score", &input_init_score); GetString(params, "input_init_score", &input_init_score);
GetString(params, "log_file", &log_file);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "valid_data", &tmp_str)) { if (GetString(params, "valid_data", &tmp_str)) {
valid_data_filenames = Common::Split(tmp_str.c_str(), ','); valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
...@@ -167,6 +187,7 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa ...@@ -167,6 +187,7 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) { void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "early_stopping_round", &early_stopping_round);
GetInt(params, "metric_freq", &output_freq); GetInt(params, "metric_freq", &output_freq);
CHECK(output_freq >= 0); CHECK(output_freq >= 0);
GetDouble(params, "sigmoid", &sigmoid); GetDouble(params, "sigmoid", &sigmoid);
...@@ -202,10 +223,13 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) ...@@ -202,10 +223,13 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf); GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0); CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0);
GetInt(params, "num_leaves", &num_leaves); GetInt(params, "num_leaves", &num_leaves);
CHECK(num_leaves > 0); CHECK(num_leaves > 1);
GetInt(params, "feature_fraction_seed", &feature_fraction_seed); GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
GetDouble(params, "feature_fraction", &feature_fraction); GetDouble(params, "feature_fraction", &feature_fraction);
CHECK(feature_fraction > 0.0 && feature_fraction <= 1.0); CHECK(feature_fraction > 0.0 && feature_fraction <= 1.0);
GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth);
CHECK(max_depth > 1 || max_depth < 0);
} }
...@@ -219,6 +243,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -219,6 +243,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
CHECK(bagging_fraction > 0.0 && bagging_fraction <= 1.0); CHECK(bagging_fraction > 0.0 && bagging_fraction <= 1.0);
GetDouble(params, "learning_rate", &learning_rate); GetDouble(params, "learning_rate", &learning_rate);
CHECK(learning_rate > 0.0); CHECK(learning_rate > 0.0);
GetInt(params, "early_stopping_round", &early_stopping_round);
CHECK(early_stopping_round >= 0);
} }
void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) { void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
...@@ -233,7 +259,7 @@ void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::s ...@@ -233,7 +259,7 @@ void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::s
tree_learner_type = TreeLearnerType::kDataParallelTreeLearner; tree_learner_type = TreeLearnerType::kDataParallelTreeLearner;
} }
else { else {
Log::Stderr("tree learner type error"); Log::Fatal("Tree learner type error");
} }
} }
} }
......
...@@ -21,7 +21,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -21,7 +21,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
CheckCanLoadFromBin(); CheckCanLoadFromBin();
if (is_loading_from_binfile_ && predict_fun != nullptr) { if (is_loading_from_binfile_ && predict_fun != nullptr) {
Log::Stdout("cannot perform initial prediction for binary file, will use text file instead"); Log::Info("Cannot performing initialization of prediction by using binary file, using text file instead");
is_loading_from_binfile_ = false; is_loading_from_binfile_ = false;
} }
...@@ -31,14 +31,14 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -31,14 +31,14 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
// create text parser // create text parser
parser_ = Parser::CreateParser(data_filename_, 0, nullptr); parser_ = Parser::CreateParser(data_filename_, 0, nullptr);
if (parser_ == nullptr) { if (parser_ == nullptr) {
Log::Stderr("cannot recognize input data format, filename: %s", data_filename_); Log::Fatal("Cannot recognising input data format, filename: %s", data_filename_);
} }
// create text reader // create text reader
text_reader_ = new TextReader<data_size_t>(data_filename); text_reader_ = new TextReader<data_size_t>(data_filename);
} else { } else {
// only need to load initial score, other meta data will be loaded from bin file // only need to load initial score, other meta data will be loaded from bin file
metadata_.Init(init_score_filename); metadata_.Init(init_score_filename);
Log::Stdout("will load data set from binary file"); Log::Info("Loading data set from binary file");
parser_ = nullptr; parser_ = nullptr;
text_reader_ = nullptr; text_reader_ = nullptr;
} }
...@@ -82,7 +82,7 @@ void Dataset::LoadDataToMemory(int rank, int num_machines, bool is_pre_partition ...@@ -82,7 +82,7 @@ void Dataset::LoadDataToMemory(int rank, int num_machines, bool is_pre_partition
[this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries] [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
(data_size_t line_idx) { (data_size_t line_idx) {
if (qid >= num_queries) { if (qid >= num_queries) {
Log::Stderr("current query is exceed the range of query file, please ensure your query file is correct"); Log::Fatal("Current query is exceed the range of query file, please ensure your query file is correct");
} }
if (line_idx >= query_boundaries[qid + 1]) { if (line_idx >= query_boundaries[qid + 1]) {
// if is new query // if is new query
...@@ -139,7 +139,7 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti ...@@ -139,7 +139,7 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti
[this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries] [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
(data_size_t line_idx) { (data_size_t line_idx) {
if (qid >= num_queries) { if (qid >= num_queries) {
Log::Stderr("current query is exceed the range of query file, \ Log::Fatal("Query id is exceed the range of query file, \
please ensure your query file is correct"); please ensure your query file is correct");
} }
if (line_idx >= query_boundaries[qid + 1]) { if (line_idx >= query_boundaries[qid + 1]) {
...@@ -189,7 +189,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -189,7 +189,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// -1 means doesn't use this feature // -1 means doesn't use this feature
used_feature_map_ = std::vector<int>(sample_values.size(), -1); used_feature_map_ = std::vector<int>(sample_values.size(), -1);
num_total_features_ = sample_values.size(); num_total_features_ = static_cast<int>(sample_values.size());
// start find bins // start find bins
if (num_machines == 1) { if (num_machines == 1) {
std::vector<BinMapper*> bin_mappers(sample_values.size()); std::vector<BinMapper*> bin_mappers(sample_values.size());
...@@ -209,7 +209,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -209,7 +209,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
num_data_, is_enable_sparse_)); num_data_, is_enable_sparse_));
} else { } else {
// if feature is trivial (only 1 bin), free spaces // if feature is trivial (only 1 bin), free spaces
Log::Stdout("Warning: feture %d only contains one value, will ignore it", i); Log::Error("Feature %d only contains one value, will be ignored", i);
delete bin_mappers[i]; delete bin_mappers[i];
} }
} }
...@@ -486,10 +486,10 @@ void Dataset::SaveBinaryFile() { ...@@ -486,10 +486,10 @@ void Dataset::SaveBinaryFile() {
file = fopen(bin_filename.c_str(), "wb"); file = fopen(bin_filename.c_str(), "wb");
#endif #endif
if (file == NULL) { if (file == NULL) {
Log::Stderr("cannot write binary data to %s ", bin_filename.c_str()); Log::Fatal("Cannot write binary data to %s ", bin_filename.c_str());
} }
Log::Stdout("start save binary file for data %s", data_filename_); Log::Info("Saving data to binary file: %s", data_filename_);
// get size of header // get size of header
size_t size_of_header = sizeof(global_num_data_) + sizeof(is_enable_sparse_) size_t size_of_header = sizeof(global_num_data_) + sizeof(is_enable_sparse_)
...@@ -556,7 +556,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -556,7 +556,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
#endif #endif
if (file == NULL) { if (file == NULL) {
Log::Stderr("cannot read binary data from %s", bin_filename.c_str()); Log::Fatal("Cannot read binary data from %s", bin_filename.c_str());
} }
// buffer to read binary file // buffer to read binary file
...@@ -567,7 +567,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -567,7 +567,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
size_t read_cnt = fread(buffer, sizeof(size_t), 1, file); size_t read_cnt = fread(buffer, sizeof(size_t), 1, file);
if (read_cnt != 1) { if (read_cnt != 1) {
Log::Stderr("binary file format error at header size"); Log::Fatal("Binary file format error at header size");
} }
size_t size_of_head = *(reinterpret_cast<size_t*>(buffer)); size_t size_of_head = *(reinterpret_cast<size_t*>(buffer));
...@@ -582,7 +582,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -582,7 +582,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt = fread(buffer, 1, size_of_head, file); read_cnt = fread(buffer, 1, size_of_head, file);
if (read_cnt != size_of_head) { if (read_cnt != size_of_head) {
Log::Stderr("binary file format error at header"); Log::Fatal("Binary file format error at header");
} }
// get header // get header
const char* mem_ptr = buffer; const char* mem_ptr = buffer;
...@@ -608,7 +608,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -608,7 +608,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt = fread(buffer, sizeof(size_t), 1, file); read_cnt = fread(buffer, sizeof(size_t), 1, file);
if (read_cnt != 1) { if (read_cnt != 1) {
Log::Stderr("binary file format error at size of meta data"); Log::Fatal("Binary file format error: wrong size of meta data");
} }
size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer)); size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer));
...@@ -623,7 +623,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -623,7 +623,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt = fread(buffer, 1, size_of_metadata, file); read_cnt = fread(buffer, 1, size_of_metadata, file);
if (read_cnt != size_of_metadata) { if (read_cnt != size_of_metadata) {
Log::Stderr("binary file format error at meta data"); Log::Fatal("Binary file format error: wrong size of meta data");
} }
// load meta data // load meta data
metadata_.LoadFromMemory(buffer); metadata_.LoadFromMemory(buffer);
...@@ -647,7 +647,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -647,7 +647,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
bool is_query_used = false; bool is_query_used = false;
for (data_size_t i = 0; i < num_data_; i++) { for (data_size_t i = 0; i < num_data_; i++) {
if (qid >= num_queries) { if (qid >= num_queries) {
Log::Stderr("current query is exceed the range of query file, please ensure your query file is correct"); Log::Fatal("current query is exceed the range of query file, please ensure your query file is correct");
} }
if (i >= query_boundaries[qid + 1]) { if (i >= query_boundaries[qid + 1]) {
// if is new query // if is new query
...@@ -670,7 +670,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -670,7 +670,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
// read feature size // read feature size
read_cnt = fread(buffer, sizeof(size_t), 1, file); read_cnt = fread(buffer, sizeof(size_t), 1, file);
if (read_cnt != 1) { if (read_cnt != 1) {
Log::Stderr("binary file format error at feature %d's size", i); Log::Fatal("Binary file format error at feature %d's size", i);
} }
size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer)); size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer));
// re-allocate space if not enough // re-allocate space if not enough
...@@ -683,7 +683,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -683,7 +683,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt = fread(buffer, 1, size_of_feature, file); read_cnt = fread(buffer, 1, size_of_feature, file);
if (read_cnt != size_of_feature) { if (read_cnt != size_of_feature) {
Log::Stderr("binary file format error at feature %d loading , read count %d", i, read_cnt); Log::Fatal("Binary file format error at feature %d loading , read count %d", i, read_cnt);
} }
features_.push_back(new Feature(buffer, static_cast<data_size_t>(global_num_data_), used_data_indices_)); features_.push_back(new Feature(buffer, static_cast<data_size_t>(global_num_data_), used_data_indices_));
} }
...@@ -693,10 +693,10 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -693,10 +693,10 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
void Dataset::CheckDataset() { void Dataset::CheckDataset() {
if (num_data_ <= 0) { if (num_data_ <= 0) {
Log::Stderr("data size of %s is zero", data_filename_); Log::Fatal("Data file %s is empty", data_filename_);
} }
if (features_.size() <= 0) { if (features_.size() <= 0) {
Log::Stderr("not useful feature of data %s", data_filename_); Log::Fatal("Usable feature of data %s is null", data_filename_);
} }
} }
......
...@@ -16,10 +16,17 @@ namespace LightGBM { ...@@ -16,10 +16,17 @@ namespace LightGBM {
template <typename VAL_T> template <typename VAL_T>
class DenseBin: public Bin { class DenseBin: public Bin {
public: public:
explicit DenseBin(data_size_t num_data) explicit DenseBin(data_size_t num_data, int default_bin)
: num_data_(num_data) { : num_data_(num_data) {
data_ = new VAL_T[num_data_]; data_ = new VAL_T[num_data_];
std::memset(data_, 0, sizeof(VAL_T)*num_data_); if (default_bin == 0) {
std::memset(data_, 0, sizeof(VAL_T)*num_data_);
} else {
VAL_T default_bin_T = static_cast<VAL_T>(default_bin);
for (data_size_t i = 0; i < num_data_; ++i) {
data_[i] = default_bin_T;
}
}
} }
~DenseBin() { ~DenseBin() {
......
...@@ -61,7 +61,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -61,7 +61,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
if (used_data_indices.size() == 0) { if (used_data_indices.size() == 0) {
// check weights // check weights
if (weights_ != nullptr && num_weights_ != num_data_) { if (weights_ != nullptr && num_weights_ != num_data_) {
Log::Stdout("init weight size doesn't equal with data file, will ignore"); Log::Error("Initial weight size doesn't equal to data, weights will be ignored");
delete[] weights_; delete[] weights_;
num_weights_ = 0; num_weights_ = 0;
weights_ = nullptr; weights_ = nullptr;
...@@ -69,7 +69,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -69,7 +69,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// check query boundaries // check query boundaries
if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_data_) { if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_data_) {
Log::Stdout("init query size doesn't equal with data file, will ignore"); Log::Error("Initial query size doesn't equal to data, queies will be ignored");
delete[] query_boundaries_; delete[] query_boundaries_;
num_queries_ = 0; num_queries_ = 0;
query_boundaries_ = nullptr; query_boundaries_ = nullptr;
...@@ -78,21 +78,21 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -78,21 +78,21 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// contain initial score file // contain initial score file
if (init_score_ != nullptr && num_init_score_ != num_data_) { if (init_score_ != nullptr && num_init_score_ != num_data_) {
delete[] init_score_; delete[] init_score_;
Log::Stdout("init score size doesn't equal with data file, will ignore"); Log::Error("Initial score size doesn't equal to data, score file will be ignored");
num_init_score_ = 0; num_init_score_ = 0;
} }
} else { } else {
data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size()); data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
// check weights // check weights
if (weights_ != nullptr && num_weights_ != num_all_data) { if (weights_ != nullptr && num_weights_ != num_all_data) {
Log::Stdout("init weight size doesn't equal with data file, will ignore"); Log::Error("Initial weights size doesn't equal to data, weights will be ignored");
delete[] weights_; delete[] weights_;
num_weights_ = 0; num_weights_ = 0;
weights_ = nullptr; weights_ = nullptr;
} }
// check query boundaries // check query boundaries
if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_all_data) { if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_all_data) {
Log::Stdout("init query size doesn't equal with data file, will ignore"); Log::Error("Initial query size doesn't equal to data , queries will be ignored");
delete[] query_boundaries_; delete[] query_boundaries_;
num_queries_ = 0; num_queries_ = 0;
query_boundaries_ = nullptr; query_boundaries_ = nullptr;
...@@ -100,7 +100,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -100,7 +100,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// contain initial score file // contain initial score file
if (init_score_ != nullptr && num_init_score_ != num_all_data) { if (init_score_ != nullptr && num_init_score_ != num_all_data) {
Log::Stdout("init score size doesn't equal with data file, will ignore"); Log::Error("Initial score size doesn't equal to data , initial scores will be ignored");
delete[] init_score_; delete[] init_score_;
num_init_score_ = 0; num_init_score_ = 0;
} }
...@@ -131,10 +131,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -131,10 +131,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
used_query.push_back(qid); used_query.push_back(qid);
data_idx += len; data_idx += len;
} else { } else {
Log::Stderr("data partition error, not according to query"); Log::Fatal("Data partition error, data didn't match queies");
} }
} else { } else {
Log::Stderr("data partition error, not according to query"); Log::Fatal("Data partition error, data didn't match queies");
} }
} }
data_size_t * old_query_boundaries = query_boundaries_; data_size_t * old_query_boundaries = query_boundaries_;
...@@ -182,7 +182,7 @@ void Metadata::LoadWeights() { ...@@ -182,7 +182,7 @@ void Metadata::LoadWeights() {
if (reader.Lines().size() <= 0) { if (reader.Lines().size() <= 0) {
return; return;
} }
Log::Stdout("Start to load weights"); Log::Info("Start loading weights");
num_weights_ = static_cast<data_size_t>(reader.Lines().size()); num_weights_ = static_cast<data_size_t>(reader.Lines().size());
weights_ = new float[num_weights_]; weights_ = new float[num_weights_];
for (data_size_t i = 0; i < num_weights_; ++i) { for (data_size_t i = 0; i < num_weights_; ++i) {
...@@ -198,7 +198,7 @@ void Metadata::LoadInitialScore() { ...@@ -198,7 +198,7 @@ void Metadata::LoadInitialScore() {
TextReader<size_t> reader(init_score_filename_); TextReader<size_t> reader(init_score_filename_);
reader.ReadAllLines(); reader.ReadAllLines();
Log::Stdout("Start to load initial score"); Log::Info("Start loading initial scores");
num_init_score_ = static_cast<data_size_t>(reader.Lines().size()); num_init_score_ = static_cast<data_size_t>(reader.Lines().size());
init_score_ = new score_t[num_init_score_]; init_score_ = new score_t[num_init_score_];
double tmp = 0.0f; double tmp = 0.0f;
...@@ -218,7 +218,7 @@ void Metadata::LoadQueryBoundaries() { ...@@ -218,7 +218,7 @@ void Metadata::LoadQueryBoundaries() {
if (reader.Lines().size() <= 0) { if (reader.Lines().size() <= 0) {
return; return;
} }
Log::Stdout("Start to load query boundries"); Log::Info("Start loading query boundries");
query_boundaries_ = new data_size_t[reader.Lines().size() + 1]; query_boundaries_ = new data_size_t[reader.Lines().size() + 1];
num_queries_ = static_cast<data_size_t>(reader.Lines().size()); num_queries_ = static_cast<data_size_t>(reader.Lines().size());
query_boundaries_[0] = 0; query_boundaries_[0] = 0;
...@@ -233,7 +233,7 @@ void Metadata::LoadQueryWeights() { ...@@ -233,7 +233,7 @@ void Metadata::LoadQueryWeights() {
if (weights_ == nullptr || query_boundaries_ == nullptr) { if (weights_ == nullptr || query_boundaries_ == nullptr) {
return; return;
} }
Log::Stdout("Start to load query weights"); Log::Info("Start loading query weights");
query_weights_ = new float[num_queries_]; query_weights_ = new float[num_queries_];
for (data_size_t i = 0; i < num_queries_; ++i) { for (data_size_t i = 0; i < num_queries_; ++i) {
query_weights_[i] = 0.0f; query_weights_[i] = 0.0f;
......
...@@ -13,12 +13,21 @@ ...@@ -13,12 +13,21 @@
namespace LightGBM { namespace LightGBM {
/*! /*!
* \brief Ordered bin for sparse features. Efficient for constructing histograms, especially for sparse bins.
* There are 2 advantages of using an ordered bin:
* 1. It groups the data by leaf, which improves the cache hit rate.
* 2. It only stores the non-zero bins, which speeds up histogram construction for sparse features.
* However, it has an additional cost: the bins must be re-ordered after every leaf split, which is expensive for dense features.
* So we only use ordered bins for sparse features.
*/ */
template <typename VAL_T> template <typename VAL_T>
class OrderedSparseBin:public OrderedBin { class OrderedSparseBin:public OrderedBin {
......
...@@ -34,7 +34,7 @@ bool CheckHasLabelForLibsvm(std::string& str) { ...@@ -34,7 +34,7 @@ bool CheckHasLabelForLibsvm(std::string& str) {
bool CheckHasLabelForTSV(std::string& str, int num_features) { bool CheckHasLabelForTSV(std::string& str, int num_features) {
str = Common::Trim(str); str = Common::Trim(str);
auto tokens = Common::Split(str.c_str(), '\t'); auto tokens = Common::Split(str.c_str(), '\t');
if (tokens.size() == num_features) { if (static_cast<int>(tokens.size()) == num_features) {
return false; return false;
} else { } else {
return true; return true;
...@@ -44,7 +44,7 @@ bool CheckHasLabelForTSV(std::string& str, int num_features) { ...@@ -44,7 +44,7 @@ bool CheckHasLabelForTSV(std::string& str, int num_features) {
bool CheckHasLabelForCSV(std::string& str, int num_features) { bool CheckHasLabelForCSV(std::string& str, int num_features) {
str = Common::Trim(str); str = Common::Trim(str);
auto tokens = Common::Split(str.c_str(), ','); auto tokens = Common::Split(str.c_str(), ',');
if (tokens.size() == num_features) { if (static_cast<int>(tokens.size()) == num_features) {
return false; return false;
} else { } else {
return true; return true;
...@@ -55,18 +55,18 @@ Parser* Parser::CreateParser(const char* filename, int num_features, bool* has_l ...@@ -55,18 +55,18 @@ Parser* Parser::CreateParser(const char* filename, int num_features, bool* has_l
std::ifstream tmp_file; std::ifstream tmp_file;
tmp_file.open(filename); tmp_file.open(filename);
if (!tmp_file.is_open()) { if (!tmp_file.is_open()) {
Log::Stderr("Data file: %s doesn't exist", filename); Log::Fatal("Data file: %s doesn't exist", filename);
} }
std::string line1, line2; std::string line1, line2;
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
std::getline(tmp_file, line1); std::getline(tmp_file, line1);
} else { } else {
Log::Stderr("Data file: %s at least should have one line", filename); Log::Fatal("Data file: %s at least should have one line", filename);
} }
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
std::getline(tmp_file, line2); std::getline(tmp_file, line2);
} else { } else {
Log::Stdout("Data file: %s only have one line", filename); Log::Error("Data file: %s only have one line", filename);
} }
tmp_file.close(); tmp_file.close();
int comma_cnt = 0, comma_cnt2 = 0; int comma_cnt = 0, comma_cnt2 = 0;
......
...@@ -20,12 +20,14 @@ public: ...@@ -20,12 +20,14 @@ public:
double val = 0.0; double val = 0.0;
while (*str != '\0') { while (*str != '\0') {
str = Common::Atof(str, &val); str = Common::Atof(str, &val);
out_features->emplace_back(idx, val); if (fabs(val) > 1e-10) {
out_features->emplace_back(idx, val);
}
++idx; ++idx;
if (*str == ',') { if (*str == ',') {
++str; ++str;
} else if (*str != '\0') { } else if (*str != '\0') {
Log::Stderr("input format error, should be CSV"); Log::Fatal("input format error, should be CSV");
} }
} }
} }
...@@ -36,7 +38,7 @@ public: ...@@ -36,7 +38,7 @@ public:
if (*str == ',') { if (*str == ',') {
++str; ++str;
} else if (*str != '\0') { } else if (*str != '\0') {
Log::Stderr("input format error, should be CSV"); Log::Fatal("input format error, should be CSV");
} }
return ParseOneLine(str, out_features); return ParseOneLine(str, out_features);
} }
...@@ -49,12 +51,14 @@ public: ...@@ -49,12 +51,14 @@ public:
double val = 0.0; double val = 0.0;
while (*str != '\0') { while (*str != '\0') {
str = Common::Atof(str, &val); str = Common::Atof(str, &val);
out_features->emplace_back(idx, val); if (fabs(val) > 1e-10) {
out_features->emplace_back(idx, val);
}
++idx; ++idx;
if (*str == '\t') { if (*str == '\t') {
++str; ++str;
} else if (*str != '\0') { } else if (*str != '\0') {
Log::Stderr("input format error, should be TSV"); Log::Fatal("input format error, should be TSV");
} }
} }
} }
...@@ -65,7 +69,7 @@ public: ...@@ -65,7 +69,7 @@ public:
if (*str == '\t') { if (*str == '\t') {
++str; ++str;
} else if (*str != '\0') { } else if (*str != '\0') {
Log::Stderr("input format error, should be TSV"); Log::Fatal("input format error, should be TSV");
} }
return ParseOneLine(str, out_features); return ParseOneLine(str, out_features);
} }
...@@ -84,7 +88,7 @@ public: ...@@ -84,7 +88,7 @@ public:
str = Common::Atof(str, &val); str = Common::Atof(str, &val);
out_features->emplace_back(idx, val); out_features->emplace_back(idx, val);
} else { } else {
Log::Stderr("input format error, should be LibSVM"); Log::Fatal("input format error, should be LibSVM");
} }
str = Common::SkipSpaceAndTab(str); str = Common::SkipSpaceAndTab(str);
} }
......
...@@ -24,8 +24,12 @@ class SparseBin:public Bin { ...@@ -24,8 +24,12 @@ class SparseBin:public Bin {
public: public:
friend class SparseBinIterator<VAL_T>; friend class SparseBinIterator<VAL_T>;
explicit SparseBin(data_size_t num_data) explicit SparseBin(data_size_t num_data, int default_bin)
: num_data_(num_data) { : num_data_(num_data) {
default_bin_ = static_cast<VAL_T>(default_bin);
if (default_bin_ != 0) {
Log::Info("Warning: Having sparse feature with negative values. Will let negative values equal zero as well");
}
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
{ {
...@@ -41,7 +45,7 @@ public: ...@@ -41,7 +45,7 @@ public:
void Push(int tid, data_size_t idx, uint32_t value) override { void Push(int tid, data_size_t idx, uint32_t value) override {
// not store zero data // not store zero data
if (value == 0) { return; } if (value <= default_bin_) { return; }
push_buffers_[tid].emplace_back(idx, static_cast<VAL_T>(value)); push_buffers_[tid].emplace_back(idx, static_cast<VAL_T>(value));
} }
...@@ -50,7 +54,7 @@ public: ...@@ -50,7 +54,7 @@ public:
void ConstructHistogram(data_size_t*, data_size_t , const score_t* , void ConstructHistogram(data_size_t*, data_size_t , const score_t* ,
const score_t* , HistogramBinEntry*) const override { const score_t* , HistogramBinEntry*) const override {
// Will use OrderedSparseBin->ConstructHistogram() instead // Will use OrderedSparseBin->ConstructHistogram() instead
Log::Stderr("Should use OrderedSparseBin->ConstructHistogram() instead"); Log::Info("Should use OrderedSparseBin->ConstructHistogram() instead");
} }
data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data, data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data,
...@@ -240,6 +244,7 @@ private: ...@@ -240,6 +244,7 @@ private:
std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_; std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
std::vector<std::pair<data_size_t, data_size_t>> fast_index_; std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
data_size_t fast_index_shift_; data_size_t fast_index_shift_;
VAL_T default_bin_;
}; };
template <typename VAL_T> template <typename VAL_T>
......
...@@ -28,6 +28,9 @@ Tree::Tree(int max_leaves) ...@@ -28,6 +28,9 @@ Tree::Tree(int max_leaves)
leaf_parent_ = new int[max_leaves_]; leaf_parent_ = new int[max_leaves_];
leaf_value_ = new score_t[max_leaves_]; leaf_value_ = new score_t[max_leaves_];
leaf_depth_ = new int[max_leaves_];
// root is in the depth 1
leaf_depth_[0] = 1;
num_leaves_ = 1; num_leaves_ = 1;
leaf_parent_[0] = -1; leaf_parent_[0] = -1;
} }
...@@ -41,6 +44,7 @@ Tree::~Tree() { ...@@ -41,6 +44,7 @@ Tree::~Tree() {
if (threshold_ != nullptr) { delete[] threshold_; } if (threshold_ != nullptr) { delete[] threshold_; }
if (split_gain_ != nullptr) { delete[] split_gain_; } if (split_gain_ != nullptr) { delete[] split_gain_; }
if (leaf_value_ != nullptr) { delete[] leaf_value_; } if (leaf_value_ != nullptr) { delete[] leaf_value_; }
if (leaf_depth_ != nullptr) { delete[] leaf_depth_; }
} }
int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feature, int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feature,
...@@ -70,9 +74,11 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat ...@@ -70,9 +74,11 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
leaf_parent_[num_leaves_] = new_node_idx; leaf_parent_[num_leaves_] = new_node_idx;
leaf_value_[leaf] = left_value; leaf_value_[leaf] = left_value;
leaf_value_[num_leaves_] = right_value; leaf_value_[num_leaves_] = right_value;
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
leaf_depth_[leaf]++;
++num_leaves_; ++num_leaves_;
return num_leaves_ - 1; return num_leaves_ - 1;
} }
...@@ -140,7 +146,7 @@ Tree::Tree(const std::string& str) { ...@@ -140,7 +146,7 @@ Tree::Tree(const std::string& str) {
|| key_vals.count("split_gain") <= 0 || key_vals.count("threshold") <= 0 || key_vals.count("split_gain") <= 0 || key_vals.count("threshold") <= 0
|| key_vals.count("left_child") <= 0 || key_vals.count("right_child") <= 0 || key_vals.count("left_child") <= 0 || key_vals.count("right_child") <= 0
|| key_vals.count("leaf_parent") <= 0 || key_vals.count("leaf_value") <= 0) { || key_vals.count("leaf_parent") <= 0 || key_vals.count("leaf_value") <= 0) {
Log::Stderr("tree model string format error"); Log::Fatal("tree model string format error");
} }
Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_); Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);
...@@ -155,6 +161,7 @@ Tree::Tree(const std::string& str) { ...@@ -155,6 +161,7 @@ Tree::Tree(const std::string& str) {
split_feature_ = nullptr; split_feature_ = nullptr;
threshold_in_bin_ = nullptr; threshold_in_bin_ = nullptr;
leaf_depth_ = nullptr;
Common::StringToIntArray(key_vals["split_feature"], ' ', Common::StringToIntArray(key_vals["split_feature"], ' ',
num_leaves_ - 1, split_feature_real_); num_leaves_ - 1, split_feature_real_);
......
...@@ -18,10 +18,12 @@ template<typename PointWiseLossCalculator> ...@@ -18,10 +18,12 @@ template<typename PointWiseLossCalculator>
class BinaryMetric: public Metric { class BinaryMetric: public Metric {
public: public:
explicit BinaryMetric(const MetricConfig& config) { explicit BinaryMetric(const MetricConfig& config) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq; output_freq_ = config.output_freq;
the_bigger_the_better = false;
sigmoid_ = static_cast<score_t>(config.sigmoid); sigmoid_ = static_cast<score_t>(config.sigmoid);
if (sigmoid_ <= 0.0f) { if (sigmoid_ <= 0.0f) {
Log::Stderr("sigmoid param %f should greater than zero", sigmoid_); Log::Fatal("Sigmoid param %f should greater than zero", sigmoid_);
} }
} }
...@@ -48,14 +50,14 @@ public: ...@@ -48,14 +50,14 @@ public:
} }
} }
void Print(int iter, const score_t* score) const override { score_t PrintAndGetLoss(int iter, const score_t* score) const override {
score_t sum_loss = 0.0f; score_t sum_loss = 0.0f;
if (output_freq_ > 0 && iter % output_freq_ == 0) { if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
if (weights_ == nullptr) { if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// sigmoid transform // sigmoid transform
score_t prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[i])); score_t prob = 1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * score[i]));
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob); sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob);
} }
...@@ -63,13 +65,18 @@ public: ...@@ -63,13 +65,18 @@ public:
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// sigmoid transform // sigmoid transform
score_t prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[i])); score_t prob = 1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * score[i]));
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i]; sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i];
} }
} }
Log::Stdout("Iteration:%d, %s's %s: %f", iter, name, PointWiseLossCalculator::Name(), sum_loss / sum_weights_); score_t loss = sum_loss / sum_weights_;
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Info("Iteration:%d, %s's %s: %f", iter, name, PointWiseLossCalculator::Name(), loss);
}
return loss;
} }
return 0.0f;
} }
private: private:
...@@ -139,7 +146,9 @@ public: ...@@ -139,7 +146,9 @@ public:
class AUCMetric: public Metric { class AUCMetric: public Metric {
public: public:
explicit AUCMetric(const MetricConfig& config) { explicit AUCMetric(const MetricConfig& config) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq; output_freq_ = config.output_freq;
the_bigger_the_better = true;
} }
virtual ~AUCMetric() { virtual ~AUCMetric() {
...@@ -163,8 +172,8 @@ public: ...@@ -163,8 +172,8 @@ public:
} }
} }
void Print(int iter, const score_t* score) const override { score_t PrintAndGetLoss(int iter, const score_t* score) const override {
if (output_freq_ > 0 && iter % output_freq_ == 0) { if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
// get indices sorted by score, descent order // get indices sorted by score, descent order
std::vector<data_size_t> sorted_idx; std::vector<data_size_t> sorted_idx;
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
...@@ -220,8 +229,12 @@ public: ...@@ -220,8 +229,12 @@ public:
if (sum_pos > 0.0f && sum_pos != sum_weights_) { if (sum_pos > 0.0f && sum_pos != sum_weights_) {
auc = accum / (sum_pos *(sum_weights_ - sum_pos)); auc = accum / (sum_pos *(sum_weights_ - sum_pos));
} }
Log::Stdout("iteration:%d, %s's %s: %f", iter, name, "auc", auc); if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Info("Iteration:%d, %s's %s: %f", iter, name, "auc", auc);
}
return auc;
} }
return 0.0f;
} }
private: private:
......
...@@ -57,7 +57,7 @@ void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks, ...@@ -57,7 +57,7 @@ void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks,
std::vector<data_size_t> label_cnt(label_gain_.size(), 0); std::vector<data_size_t> label_cnt(label_gain_.size(), 0);
// counts for all labels // counts for all labels
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
if (static_cast<size_t>(label[i]) >= label_cnt.size()) { Log::Stderr("label excel %d\n", label[i]); } if (static_cast<size_t>(label[i]) >= label_cnt.size()) { Log::Fatal("label excel %d", label[i]); }
++label_cnt[static_cast<int>(label[i])]; ++label_cnt[static_cast<int>(label[i])];
} }
double cur_result = 0.0; double cur_result = 0.0;
......
...@@ -16,7 +16,9 @@ namespace LightGBM { ...@@ -16,7 +16,9 @@ namespace LightGBM {
class NDCGMetric:public Metric { class NDCGMetric:public Metric {
public: public:
explicit NDCGMetric(const MetricConfig& config) { explicit NDCGMetric(const MetricConfig& config) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq; output_freq_ = config.output_freq;
the_bigger_the_better = true;
// get eval position // get eval position
for (auto k : config.eval_at) { for (auto k : config.eval_at) {
eval_at_.push_back(static_cast<data_size_t>(k)); eval_at_.push_back(static_cast<data_size_t>(k));
...@@ -41,7 +43,7 @@ public: ...@@ -41,7 +43,7 @@ public:
// get query boundaries // get query boundaries
query_boundaries_ = metadata.query_boundaries(); query_boundaries_ = metadata.query_boundaries();
if (query_boundaries_ == nullptr) { if (query_boundaries_ == nullptr) {
Log::Stderr("For NDCG metric, should have query information"); Log::Fatal("For NDCG metric, there should be query information");
} }
num_queries_ = metadata.num_queries(); num_queries_ = metadata.num_queries();
// get query weights // get query weights
...@@ -73,8 +75,8 @@ public: ...@@ -73,8 +75,8 @@ public:
} }
} }
void Print(int iter, const score_t* score) const override { score_t PrintAndGetLoss(int iter, const score_t* score) const override {
if (output_freq_ > 0 && iter % output_freq_ == 0) { if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
// some buffers for multi-threading sum up // some buffers for multi-threading sum up
std::vector<std::vector<double>> result_buffer_; std::vector<std::vector<double>> result_buffer_;
for (int i = 0; i < num_threads_; ++i) { for (int i = 0; i < num_threads_; ++i) {
...@@ -132,8 +134,12 @@ public: ...@@ -132,8 +134,12 @@ public:
result[j] /= sum_query_weights_; result[j] /= sum_query_weights_;
result_ss << "NDCG@" << eval_at_[j] << ":" << result[j] << "\t"; result_ss << "NDCG@" << eval_at_[j] << ":" << result[j] << "\t";
} }
Log::Stdout("Iteration:%d, Test:%s, %s ", iter, name, result_ss.str().c_str()); if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Info("Iteration:%d, Test:%s, %s ", iter, name, result_ss.str().c_str());
}
return result[0];
} }
return 0.0f;
} }
private: private:
......
...@@ -16,7 +16,9 @@ template<typename PointWiseLossCalculator> ...@@ -16,7 +16,9 @@ template<typename PointWiseLossCalculator>
class RegressionMetric: public Metric { class RegressionMetric: public Metric {
public: public:
explicit RegressionMetric(const MetricConfig& config) { explicit RegressionMetric(const MetricConfig& config) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq; output_freq_ = config.output_freq;
the_bigger_the_better = false;
} }
virtual ~RegressionMetric() { virtual ~RegressionMetric() {
...@@ -39,9 +41,9 @@ public: ...@@ -39,9 +41,9 @@ public:
} }
} }
} }
void Print(int iter, const score_t* score) const override { score_t PrintAndGetLoss(int iter, const score_t* score) const override {
if (output_freq_ > 0 && iter % output_freq_ == 0) { if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
score_t sum_loss = 0.0; score_t sum_loss = 0.0;
if (weights_ == nullptr) { if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
...@@ -56,8 +58,13 @@ public: ...@@ -56,8 +58,13 @@ public:
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]) * weights_[i]; sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]) * weights_[i];
} }
} }
Log::Stdout("Iteration:%d, %s's %s : %f", iter, name, PointWiseLossCalculator::Name(), PointWiseLossCalculator::AverageLoss(sum_loss, sum_weights_)); score_t loss = PointWiseLossCalculator::AverageLoss(sum_loss, sum_weights_);
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Info("Iteration:%d, %s's %s : %f", iter, name, PointWiseLossCalculator::Name(), loss);
}
return loss;
} }
return 0.0f;
} }
inline static score_t AverageLoss(score_t sum_loss, score_t sum_weights) { inline static score_t AverageLoss(score_t sum_loss, score_t sum_weights) {
......
...@@ -44,7 +44,7 @@ Linkers::Linkers(NetworkConfig config) { ...@@ -44,7 +44,7 @@ Linkers::Linkers(NetworkConfig config) {
} }
} }
if (rank_ == -1) { if (rank_ == -1) {
Log::Stderr("machine list file doesn't contain local machine, app quit"); Log::Fatal("Machine list file doesn't contain local machine");
} }
// construct listener // construct listener
listener_ = new TcpSocket(); listener_ = new TcpSocket();
...@@ -73,14 +73,14 @@ Linkers::~Linkers() { ...@@ -73,14 +73,14 @@ Linkers::~Linkers() {
} }
} }
TcpSocket::Finalize(); TcpSocket::Finalize();
Log::Stdout("network used %f seconds", network_time_ * 1e-3); Log::Info("Network using %f seconds", network_time_ * 1e-3);
} }
void Linkers::ParseMachineList(const char * filename) { void Linkers::ParseMachineList(const char * filename) {
TextReader<size_t> machine_list_reader(filename); TextReader<size_t> machine_list_reader(filename);
machine_list_reader.ReadAllLines(); machine_list_reader.ReadAllLines();
if (machine_list_reader.Lines().size() <= 0) { if (machine_list_reader.Lines().size() <= 0) {
Log::Stderr("machine list file:%s doesn't exist", filename); Log::Fatal("Machine list file:%s doesn't exist", filename);
} }
for (auto& line : machine_list_reader.Lines()) { for (auto& line : machine_list_reader.Lines()) {
...@@ -95,7 +95,7 @@ void Linkers::ParseMachineList(const char * filename) { ...@@ -95,7 +95,7 @@ void Linkers::ParseMachineList(const char * filename) {
continue; continue;
} }
if (client_ips_.size() >= static_cast<size_t>(num_machines_)) { if (client_ips_.size() >= static_cast<size_t>(num_machines_)) {
Log::Stdout("The #machine in machine list is larger than parameter num_machines, will ignore rest"); Log::Error("The #machine in machine_list is larger than parameter num_machines, the redundant will ignored");
break; break;
} }
str_after_split[0] = Common::Trim(str_after_split[0]); str_after_split[0] = Common::Trim(str_after_split[0]);
...@@ -104,17 +104,17 @@ void Linkers::ParseMachineList(const char * filename) { ...@@ -104,17 +104,17 @@ void Linkers::ParseMachineList(const char * filename) {
client_ports_.push_back(atoi(str_after_split[1].c_str())); client_ports_.push_back(atoi(str_after_split[1].c_str()));
} }
if (client_ips_.size() != static_cast<size_t>(num_machines_)) { if (client_ips_.size() != static_cast<size_t>(num_machines_)) {
Log::Stdout("The world size is bigger the #machine in machine list, change world size to %d .", client_ips_.size()); Log::Error("The world size is bigger the #machine in machine list, change world size to %d .", client_ips_.size());
num_machines_ = static_cast<int>(client_ips_.size()); num_machines_ = static_cast<int>(client_ips_.size());
} }
} }
void Linkers::TryBind(int port) { void Linkers::TryBind(int port) {
Log::Stdout("try to bind port %d.", port); Log::Info("try to bind port %d.", port);
if (listener_->Bind(port)) { if (listener_->Bind(port)) {
Log::Stdout("bind port %d success.", port); Log::Info("Binding port %d success.", port);
} else { } else {
Log::Stderr("bind port %d failed.", port); Log::Fatal("Binding port %d failed.", port);
} }
} }
...@@ -125,7 +125,7 @@ void Linkers::SetLinker(int rank, const TcpSocket& socket) { ...@@ -125,7 +125,7 @@ void Linkers::SetLinker(int rank, const TcpSocket& socket) {
} }
void Linkers::ListenThread(int incoming_cnt) { void Linkers::ListenThread(int incoming_cnt) {
Log::Stdout("Listening..."); Log::Info("Listening...");
char buffer[100]; char buffer[100];
int connected_cnt = 0; int connected_cnt = 0;
while (connected_cnt < incoming_cnt) { while (connected_cnt < incoming_cnt) {
...@@ -192,7 +192,7 @@ void Linkers::Construct() { ...@@ -192,7 +192,7 @@ void Linkers::Construct() {
if (cur_socket.Connect(client_ips_[out_rank].c_str(), client_ports_[out_rank])) { if (cur_socket.Connect(client_ips_[out_rank].c_str(), client_ports_[out_rank])) {
break; break;
} else { } else {
Log::Stdout("connect to rank %d failed, wait for %d milliseconds", out_rank, connect_fail_delay_time); Log::Error("Connect to rank %d failed, wait for %d milliseconds", out_rank, connect_fail_delay_time);
std::this_thread::sleep_for(std::chrono::milliseconds(connect_fail_delay_time)); std::this_thread::sleep_for(std::chrono::milliseconds(connect_fail_delay_time));
} }
} }
...@@ -217,7 +217,7 @@ bool Linkers::CheckLinker(int rank) { ...@@ -217,7 +217,7 @@ bool Linkers::CheckLinker(int rank) {
void Linkers::PrintLinkers() { void Linkers::PrintLinkers() {
for (int i = 0; i < num_machines_; ++i) { for (int i = 0; i < num_machines_; ++i) {
if (CheckLinker(i)) { if (CheckLinker(i)) {
Log::Stdout("Connected to rank %d.", i); Log::Info("Connected to rank %d.", i);
} }
} }
} }
......
...@@ -30,7 +30,7 @@ void Network::Init(NetworkConfig config) { ...@@ -30,7 +30,7 @@ void Network::Init(NetworkConfig config) {
block_len_ = new int[num_machines_]; block_len_ = new int[num_machines_];
buffer_size_ = 1024 * 1024; buffer_size_ = 1024 * 1024;
buffer_ = new char[buffer_size_]; buffer_ = new char[buffer_size_];
Log::Stdout("local rank %d, total number of machines %d", rank_, num_machines_); Log::Info("local rank %d, total number of machines %d", rank_, num_machines_);
} }
void Network::Dispose() { void Network::Dispose() {
......
...@@ -60,7 +60,7 @@ public: ...@@ -60,7 +60,7 @@ public:
TcpSocket() { TcpSocket() {
sockfd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); sockfd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (sockfd_ == INVALID_SOCKET) { if (sockfd_ == INVALID_SOCKET) {
Log::Stderr("socket construct error"); Log::Fatal("Socket construct error");
return; return;
} }
ConfigSocket(); ConfigSocket();
...@@ -69,7 +69,7 @@ public: ...@@ -69,7 +69,7 @@ public:
explicit TcpSocket(SOCKET socket) { explicit TcpSocket(SOCKET socket) {
sockfd_ = socket; sockfd_ = socket;
if (sockfd_ == INVALID_SOCKET) { if (sockfd_ == INVALID_SOCKET) {
Log::Stderr("passed socket error"); Log::Fatal("Passed socket error");
return; return;
} }
ConfigSocket(); ConfigSocket();
...@@ -97,11 +97,11 @@ public: ...@@ -97,11 +97,11 @@ public:
#if defined(_WIN32) #if defined(_WIN32)
WSADATA wsa_data; WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
Log::Stderr("socket error: start up error"); Log::Fatal("Socket error: WSAStart up error");
} }
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
WSACleanup(); WSACleanup();
Log::Stderr("socket error: Winsock.dll version error"); Log::Fatal("Socket error: Winsock.dll version error");
} }
#else #else
#endif #endif
...@@ -128,7 +128,7 @@ public: ...@@ -128,7 +128,7 @@ public:
char buffer[512]; char buffer[512];
// get hostName // get hostName
if (gethostname(buffer, sizeof(buffer)) == SOCKET_ERROR) { if (gethostname(buffer, sizeof(buffer)) == SOCKET_ERROR) {
Log::Stderr("Error code: %d, when getting local host name.", WSAGetLastError()); Log::Fatal("Error code: %d, when getting local host name.", WSAGetLastError());
} }
// push local ip // push local ip
PIP_ADAPTER_INFO pAdapterInfo; PIP_ADAPTER_INFO pAdapterInfo;
...@@ -137,7 +137,7 @@ public: ...@@ -137,7 +137,7 @@ public:
ULONG ulOutBufLen = sizeof(IP_ADAPTER_INFO); ULONG ulOutBufLen = sizeof(IP_ADAPTER_INFO);
pAdapterInfo = (IP_ADAPTER_INFO *)MALLOC(sizeof(IP_ADAPTER_INFO)); pAdapterInfo = (IP_ADAPTER_INFO *)MALLOC(sizeof(IP_ADAPTER_INFO));
if (pAdapterInfo == NULL) { if (pAdapterInfo == NULL) {
Log::Stderr("Error allocating memory needed to call GetAdaptersinfo\n"); Log::Fatal("GetAdaptersinfo error: allocating memory ");
} }
// Make an initial call to GetAdaptersInfo to get // Make an initial call to GetAdaptersInfo to get
// the necessary size into the ulOutBufLen variable // the necessary size into the ulOutBufLen variable
...@@ -145,7 +145,7 @@ public: ...@@ -145,7 +145,7 @@ public:
FREE(pAdapterInfo); FREE(pAdapterInfo);
pAdapterInfo = (IP_ADAPTER_INFO *)MALLOC(ulOutBufLen); pAdapterInfo = (IP_ADAPTER_INFO *)MALLOC(ulOutBufLen);
if (pAdapterInfo == NULL) { if (pAdapterInfo == NULL) {
Log::Stderr("Error allocating memory needed to call GetAdaptersinfo\n"); Log::Fatal("GetAdaptersinfo error: allocating memory ");
} }
} }
if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) { if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) {
...@@ -155,7 +155,7 @@ public: ...@@ -155,7 +155,7 @@ public:
pAdapter = pAdapter->Next; pAdapter = pAdapter->Next;
} }
} else { } else {
printf("GetAdaptersInfo failed with error: %d\n", dwRetVal); Log::Error("GetAdaptersinfo error: code %d ", dwRetVal);
} }
if (pAdapterInfo) if (pAdapterInfo)
FREE(pAdapterInfo); FREE(pAdapterInfo);
...@@ -218,7 +218,7 @@ public: ...@@ -218,7 +218,7 @@ public:
inline TcpSocket Accept() { inline TcpSocket Accept() {
SOCKET newfd = accept(sockfd_, NULL, NULL); SOCKET newfd = accept(sockfd_, NULL, NULL);
if (newfd == INVALID_SOCKET) { if (newfd == INVALID_SOCKET) {
Log::Stderr("socket accept error,error code: %d", GetLastError()); Log::Fatal("Socket accept error, code: %d", GetLastError());
} }
return TcpSocket(newfd); return TcpSocket(newfd);
} }
...@@ -226,7 +226,7 @@ public: ...@@ -226,7 +226,7 @@ public:
inline int Send(const char *buf_, int len, int flag = 0) { inline int Send(const char *buf_, int len, int flag = 0) {
int cur_cnt = send(sockfd_, buf_, len, flag); int cur_cnt = send(sockfd_, buf_, len, flag);
if (cur_cnt == SOCKET_ERROR) { if (cur_cnt == SOCKET_ERROR) {
Log::Stderr("socket send error, error code: %d", GetLastError()); Log::Fatal("Socket send error, code: %d", GetLastError());
} }
return cur_cnt; return cur_cnt;
} }
...@@ -234,7 +234,7 @@ public: ...@@ -234,7 +234,7 @@ public:
inline int Recv(char *buf_, int len, int flags = 0) { inline int Recv(char *buf_, int len, int flags = 0) {
int cur_cnt = recv(sockfd_, buf_ , len , flags); int cur_cnt = recv(sockfd_, buf_ , len , flags);
if (cur_cnt == SOCKET_ERROR) { if (cur_cnt == SOCKET_ERROR) {
Log::Stderr("socket recv error, error code: %d", GetLastError()); Log::Fatal("Socket recv error, code: %d", GetLastError());
} }
return cur_cnt; return cur_cnt;
} }
......
...@@ -16,7 +16,7 @@ public: ...@@ -16,7 +16,7 @@ public:
is_unbalance_ = config.is_unbalance; is_unbalance_ = config.is_unbalance;
sigmoid_ = static_cast<score_t>(config.sigmoid); sigmoid_ = static_cast<score_t>(config.sigmoid);
if (sigmoid_ <= 0.0) { if (sigmoid_ <= 0.0) {
Log::Stderr("sigmoid param %f should greater than zero", sigmoid_); Log::Fatal("Sigmoid parameter %f :should greater than zero", sigmoid_);
} }
} }
~BinaryLogloss() {} ~BinaryLogloss() {}
...@@ -34,10 +34,10 @@ public: ...@@ -34,10 +34,10 @@ public:
++cnt_negative; ++cnt_negative;
} }
} }
Log::Stdout("number of postive:%d number of negative:%d", cnt_positive, cnt_negative); Log::Info("Number of postive:%d, number of negative:%d", cnt_positive, cnt_negative);
// cannot continue if all sample are same class // cannot continue if all sample are same class
if (cnt_positive == 0 || cnt_negative == 0) { if (cnt_positive == 0 || cnt_negative == 0) {
Log::Stderr("input training data only contain one class"); Log::Fatal("Input training data only contains one class");
} }
// use -1 for negative class, and 1 for positive class // use -1 for negative class, and 1 for positive class
label_val_[0] = -1; label_val_[0] = -1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment