Commit ab5ab8a7 authored by Qiwei Ye's avatar Qiwei Ye
Browse files

Merge branch 'master' of https://github.com/Microsoft/LightGBM

Conflicts:
	src/io/dataset.cpp
parents 7d25c253 0241d323
...@@ -2,21 +2,21 @@ LightGBM, Light Gradient Boosting Machine ...@@ -2,21 +2,21 @@ LightGBM, Light Gradient Boosting Machine
========== ==========
[![Build Status](https://travis-ci.org/Microsoft/LightGBM.svg?branch=master)](https://travis-ci.org/Microsoft/LightGBM) [![Build Status](https://travis-ci.org/Microsoft/LightGBM.svg?branch=master)](https://travis-ci.org/Microsoft/LightGBM)
LightGBM is a gradient boosting framework that is using tree based learning algorithms. It is designed to be distributed and efficient with following advantages: LightGBM is a gradient boosting framework that uses tree based learning algorithms. It is designed to be distributed and efficient with the following advantages:
- Fast training speed and high efficiency - Faster training speed and higher efficiency
- Lower memory usage - Lower memory usage
- Better accuracy - Better accuracy
- Parallel learning supported - Parallel learning supported
- Capability of handling large-scaling data - Capable of handling large-scale data
For more details, please refer to [Features](https://github.com/Microsoft/LightGBM/wiki/Features). For more details, please refer to [Features](https://github.com/Microsoft/LightGBM/wiki/Features).
The [experiments](https://github.com/Microsoft/LightGBM/wiki/Experiments#comparison-experiment) on public datasets show that LightGBM outperform other existing boosting tools on both efficiency and accuracy, with significant lower memory consumption. What's more, the [experiments](https://github.com/Microsoft/LightGBM/wiki/Experiments#parallel-experiment) show that LightGBM can achieve linear speed-up by using multiple machines for training in specific settings. [Experiments](https://github.com/Microsoft/LightGBM/wiki/Experiments#comparison-experiment) on public datasets show that LightGBM can outperform other existing boosting framework on both efficiency and accuracy, with significant lower memory consumption. What's more, the [experiments](https://github.com/Microsoft/LightGBM/wiki/Experiments#parallel-experiment) show that LightGBM can achieve a linear speed-up by using multiple machines for training in specific settings.
Get Started Get Started
------------ ------------
For a quick start, please follow the [Installation Guide](https://github.com/Microsoft/LightGBM/wiki/Installation-Guide) and [Quick Start](https://github.com/Microsoft/LightGBM/wiki/Quick-Start). To get started, please follow the [Installation Guide](https://github.com/Microsoft/LightGBM/wiki/Installation-Guide) and [Quick Start](https://github.com/Microsoft/LightGBM/wiki/Quick-Start).
Documents Documents
------------ ------------
...@@ -28,8 +28,6 @@ Documents ...@@ -28,8 +28,6 @@ Documents
* [**Parallel Learning Guide**](https://github.com/Microsoft/LightGBM/wiki/Parallel-Learning-Guide) * [**Parallel Learning Guide**](https://github.com/Microsoft/LightGBM/wiki/Parallel-Learning-Guide)
* [**Configuration**](https://github.com/Microsoft/LightGBM/wiki/Configuration) * [**Configuration**](https://github.com/Microsoft/LightGBM/wiki/Configuration)
Microsoft Open Source Code of Conduct Microsoft Open Source Code of Conduct
------------ ------------
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
Multiclass Classification Example
=====================
Here is an example for LightGBM to run multiclass classification task.
***You should copy executable file to this folder first.***
#### Training
For windows, by running following command in this folder:
```
lightgbm.exe config=train.conf
```
For linux, by running following command in this folder:
```
./lightgbm config=train.conf
```
#### Prediction
You should finish training first.
For windows, by running following command in this folder:
```
lightgbm.exe config=predict.conf
```
For linux, by running following command in this folder:
```
./lightgbm config=predict.conf
```
...@@ -55,37 +55,31 @@ public: ...@@ -55,37 +55,31 @@ public:
/*! /*!
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
virtual double PredictRaw(const double* feature_values, virtual double PredictRaw(const double* feature_values) const = 0;
int num_used_model) const = 0;
/*! /*!
* \brief Prediction for one record, sigmoid transformation will be used if needed * \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
virtual double Predict(const double* feature_values, virtual double Predict(const double* feature_values) const = 0;
int num_used_model) const = 0;
/*! /*!
* \brief Predtion for one record with leaf index * \brief Predtion for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
virtual std::vector<int> PredictLeafIndex( virtual std::vector<int> PredictLeafIndex(
const double* feature_values, const double* feature_values) const = 0;
int num_used_model) const = 0;
/*! /*!
* \brief Predtion for multiclass classification * \brief Predtion for multiclass classification
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result, num_class numbers per line * \return Prediction result, num_class numbers per line
*/ */
virtual std::vector<double> PredictMulticlass(const double* value, int num_used_model) const = 0; virtual std::vector<double> PredictMulticlass(const double* value) const = 0;
/*! /*!
* \brief save model to file * \brief save model to file
...@@ -122,6 +116,11 @@ public: ...@@ -122,6 +116,11 @@ public:
*/ */
virtual int NumberOfClass() const = 0; virtual int NumberOfClass() const = 0;
/*!
* \brief Set number of used model for prediction
*/
virtual void SetNumUsedModel(int num_used_model) = 0;
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
*/ */
......
...@@ -86,6 +86,7 @@ enum TaskType { ...@@ -86,6 +86,7 @@ enum TaskType {
struct IOConfig: public ConfigBase { struct IOConfig: public ConfigBase {
public: public:
int max_bin = 256; int max_bin = 256;
int num_class = 1;
int data_random_seed = 1; int data_random_seed = 1;
std::string data_filename = ""; std::string data_filename = "";
std::vector<std::string> valid_data_filenames; std::vector<std::string> valid_data_filenames;
...@@ -258,7 +259,7 @@ inline bool ConfigBase::GetInt( ...@@ -258,7 +259,7 @@ inline bool ConfigBase::GetInt(
const std::string& name, int* out) { const std::string& name, int* out) {
if (params.count(name) > 0) { if (params.count(name) > 0) {
if (!Common::AtoiAndCheck(params.at(name).c_str(), out)) { if (!Common::AtoiAndCheck(params.at(name).c_str(), out)) {
Log::Fatal("Parameter %s should be int type, passed is [%s]", Log::Fatal("Parameter %s should be of type int, got [%s]",
name.c_str(), params.at(name).c_str()); name.c_str(), params.at(name).c_str());
} }
return true; return true;
...@@ -271,7 +272,7 @@ inline bool ConfigBase::GetDouble( ...@@ -271,7 +272,7 @@ inline bool ConfigBase::GetDouble(
const std::string& name, double* out) { const std::string& name, double* out) {
if (params.count(name) > 0) { if (params.count(name) > 0) {
if (!Common::AtofAndCheck(params.at(name).c_str(), out)) { if (!Common::AtofAndCheck(params.at(name).c_str(), out)) {
Log::Fatal("Parameter %s should be double type, passed is [%s]", Log::Fatal("Parameter %s should be of type double, got [%s]",
name.c_str(), params.at(name).c_str()); name.c_str(), params.at(name).c_str());
} }
return true; return true;
...@@ -290,7 +291,7 @@ inline bool ConfigBase::GetBool( ...@@ -290,7 +291,7 @@ inline bool ConfigBase::GetBool(
} else if (value == std::string("true") || value == std::string("+")) { } else if (value == std::string("true") || value == std::string("+")) {
*out = true; *out = true;
} else { } else {
Log::Fatal("Parameter %s should be \"true\"/\"+\" or \"false\"/\"-\", passed is [%s]", Log::Fatal("Parameter %s should be \"true\"/\"+\" or \"false\"/\"-\", got [%s]",
name.c_str(), params.at(name).c_str()); name.c_str(), params.at(name).c_str());
} }
return true; return true;
......
...@@ -41,14 +41,15 @@ public: ...@@ -41,14 +41,15 @@ public:
* \brief Initialization will load qurey level informations, since it is need for sampling data * \brief Initialization will load qurey level informations, since it is need for sampling data
* \param data_filename Filename of data * \param data_filename Filename of data
* \param init_score_filename Filename of initial score * \param init_score_filename Filename of initial score
* \param is_int_label True if label is int type * \param num_class Number of classes
*/ */
void Init(const char* data_filename, const char* init_score_filename); void Init(const char* data_filename, const char* init_score_filename, const int num_class);
/*! /*!
* \brief Initialize, only load initial score * \brief Initialize, only load initial score
* \param init_score_filename Filename of initial score * \param init_score_filename Filename of initial score
* \param num_class Number of classes
*/ */
void Init(const char* init_score_filename); void Init(const char* init_score_filename, const int num_class);
/*! /*!
* \brief Initial with binary memory * \brief Initial with binary memory
* \param memory Pointer to memory * \param memory Pointer to memory
...@@ -60,10 +61,11 @@ public: ...@@ -60,10 +61,11 @@ public:
/*! /*!
* \brief Initial work, will allocate space for label, weight(if exists) and query(if exists) * \brief Initial work, will allocate space for label, weight(if exists) and query(if exists)
* \param num_data Number of training data * \param num_data Number of training data
* \param num_class Number of classes
* \param weight_idx Index of weight column, < 0 means doesn't exists * \param weight_idx Index of weight column, < 0 means doesn't exists
* \param query_idx Index of query id column, < 0 means doesn't exists * \param query_idx Index of query id column, < 0 means doesn't exists
*/ */
void Init(data_size_t num_data, int weight_idx, int query_idx); void Init(data_size_t num_data, int num_class, int weight_idx, int query_idx);
/*! /*!
* \brief Partition label by used indices * \brief Partition label by used indices
...@@ -128,7 +130,7 @@ public: ...@@ -128,7 +130,7 @@ public:
* \param idx Index of this record * \param idx Index of this record
* \param value Query Id value of this record * \param value Query Id value of this record
*/ */
inline void SetQueryAt(data_size_t idx, float value) inline void SetQueryAt(data_size_t idx, data_size_t value)
{ {
queries_[idx] = static_cast<data_size_t>(value); queries_[idx] = static_cast<data_size_t>(value);
} }
...@@ -184,6 +186,8 @@ private: ...@@ -184,6 +186,8 @@ private:
const char* init_score_filename_; const char* init_score_filename_;
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
/*! \brief Number of weights, used to check correct weight file */ /*! \brief Number of weights, used to check correct weight file */
data_size_t num_weights_; data_size_t num_weights_;
/*! \brief Label data */ /*! \brief Label data */
...@@ -234,7 +238,7 @@ public: ...@@ -234,7 +238,7 @@ public:
}; };
using PredictFunction = using PredictFunction =
std::function<double(const std::vector<std::pair<int, double>>&)>; std::function<std::vector<double>(const std::vector<std::pair<int, double>>&)>;
/*! \brief The main class of data set, /*! \brief The main class of data set,
* which are used to traning or validation * which are used to traning or validation
...@@ -398,6 +402,8 @@ private: ...@@ -398,6 +402,8 @@ private:
int num_total_features_; int num_total_features_;
/*! \brief Number of total data*/ /*! \brief Number of total data*/
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes*/
int num_class_;
/*! \brief Store some label level data*/ /*! \brief Store some label level data*/
Metadata metadata_; Metadata metadata_;
/*! \brief Random generator*/ /*! \brief Random generator*/
......
...@@ -179,7 +179,7 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -179,7 +179,7 @@ inline static const char* Atof(const char* p, double* out) {
} else if (tmp_str == std::string("inf") || tmp_str == std::string("infinity")) { } else if (tmp_str == std::string("inf") || tmp_str == std::string("infinity")) {
*out = sign * 1e308; *out = sign * 1e308;
} else { } else {
Log::Fatal("Unknow token %s in data file", tmp_str.c_str()); Log::Fatal("Unknown token %s in data file", tmp_str.c_str());
} }
p += cnt; p += cnt;
} }
...@@ -255,7 +255,7 @@ inline static std::string ArrayToString(std::vector<T> arr, char delimiter) { ...@@ -255,7 +255,7 @@ inline static std::string ArrayToString(std::vector<T> arr, char delimiter) {
inline static void StringToIntArray(const std::string& str, char delimiter, size_t n, int* out) { inline static void StringToIntArray(const std::string& str, char delimiter, size_t n, int* out) {
std::vector<std::string> strs = Split(str.c_str(), delimiter); std::vector<std::string> strs = Split(str.c_str(), delimiter);
if (strs.size() != n) { if (strs.size() != n) {
Log::Fatal("StringToIntArray error, size doesn't matched."); Log::Fatal("StringToIntArray error, size doesn't match.");
} }
for (size_t i = 0; i < strs.size(); ++i) { for (size_t i = 0; i < strs.size(); ++i) {
strs[i] = Trim(strs[i]); strs[i] = Trim(strs[i]);
...@@ -267,7 +267,7 @@ inline static void StringToIntArray(const std::string& str, char delimiter, size ...@@ -267,7 +267,7 @@ inline static void StringToIntArray(const std::string& str, char delimiter, size
inline static void StringToDoubleArray(const std::string& str, char delimiter, size_t n, double* out) { inline static void StringToDoubleArray(const std::string& str, char delimiter, size_t n, double* out) {
std::vector<std::string> strs = Split(str.c_str(), delimiter); std::vector<std::string> strs = Split(str.c_str(), delimiter);
if (strs.size() != n) { if (strs.size() != n) {
Log::Fatal("StringToDoubleArray error, size doesn't matched."); Log::Fatal("StringToDoubleArray error, size doesn't match.");
} }
for (size_t i = 0; i < strs.size(); ++i) { for (size_t i = 0; i < strs.size(); ++i) {
strs[i] = Trim(strs[i]); strs[i] = Trim(strs[i]);
......
...@@ -34,6 +34,9 @@ public: ...@@ -34,6 +34,9 @@ public:
#else #else
file = fopen(filename, "r"); file = fopen(filename, "r");
#endif #endif
if (file == NULL) {
Log::Fatal("Could not open %s", filename);
}
std::stringstream str_buf; std::stringstream str_buf;
int read_c = -1; int read_c = -1;
read_c = fgetc(file); read_c = fgetc(file);
...@@ -56,7 +59,7 @@ public: ...@@ -56,7 +59,7 @@ public:
} }
fclose(file); fclose(file);
first_line_ = str_buf.str(); first_line_ = str_buf.str();
Log::Debug("skip header:\"%s\" in file %s", first_line_.c_str(), filename_); Log::Debug("Skipped header \"%s\" in file %s", first_line_.c_str(), filename_);
} }
} }
/*! /*!
...@@ -126,7 +129,7 @@ public: ...@@ -126,7 +129,7 @@ public:
}); });
// if last line of file doesn't contain end of line // if last line of file doesn't contain end of line
if (last_line_.size() > 0) { if (last_line_.size() > 0) {
Log::Info("Warning: last line of file %s doesn't contain end of line, application will still use this line", filename_); Log::Info("Warning: last line of %s has no end of line, still using this line", filename_);
process_fun(total_cnt, last_line_.c_str(), last_line_.size()); process_fun(total_cnt, last_line_.c_str(), last_line_.size());
++total_cnt; ++total_cnt;
last_line_ = ""; last_line_ = "";
...@@ -263,7 +266,7 @@ public: ...@@ -263,7 +266,7 @@ public:
}); });
// if last line of file doesn't contain end of line // if last line of file doesn't contain end of line
if (last_line_.size() > 0) { if (last_line_.size() > 0) {
Log::Info("Warning: last line of file %s doesn't contain end of line, application will still use this line", filename_); Log::Info("Warning: last line of %s has no end of line, still using this line", filename_);
if (filter_fun(used_cnt, total_cnt)) { if (filter_fun(used_cnt, total_cnt)) {
lines_.push_back(last_line_); lines_.push_back(last_line_);
process_fun(used_cnt, lines_); process_fun(used_cnt, lines_);
......
...@@ -95,7 +95,7 @@ void Application::LoadParameters(int argc, char** argv) { ...@@ -95,7 +95,7 @@ void Application::LoadParameters(int argc, char** argv) {
if (key.size() <= 0) { if (key.size() <= 0) {
continue; continue;
} }
// Command line have higher priority // Command-line has higher priority
if (params.count(key) == 0) { if (params.count(key) == 0) {
params[key] = value; params[key] = value;
} }
...@@ -105,7 +105,7 @@ void Application::LoadParameters(int argc, char** argv) { ...@@ -105,7 +105,7 @@ void Application::LoadParameters(int argc, char** argv) {
} }
} }
} else { } else {
Log::Warning("Config file: %s doesn't exist, will ignore", Log::Warning("Config file %s doesn't exist, will ignore",
params["config_file"].c_str()); params["config_file"].c_str());
} }
} }
...@@ -113,21 +113,28 @@ void Application::LoadParameters(int argc, char** argv) { ...@@ -113,21 +113,28 @@ void Application::LoadParameters(int argc, char** argv) {
ParameterAlias::KeyAliasTransform(&params); ParameterAlias::KeyAliasTransform(&params);
// load configs // load configs
config_.Set(params); config_.Set(params);
Log::Info("Loading parameters .. finished"); Log::Info("Finished loading parameters");
} }
void Application::LoadData() { void Application::LoadData() {
auto start_time = std::chrono::high_resolution_clock::now(); auto start_time = std::chrono::high_resolution_clock::now();
// predition is needed if using input initial model(continued train) // prediction is needed if using input initial model(continued train)
PredictFunction predict_fun = nullptr; PredictFunction predict_fun = nullptr;
Predictor* predictor = nullptr; Predictor* predictor = nullptr;
// need to continue train // need to continue training
if (boosting_->NumberOfSubModels() > 0) { if (boosting_->NumberOfSubModels() > 0) {
predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index, -1); predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index);
if (config_.io_config.num_class == 1){
predict_fun = predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) { [&predictor](const std::vector<std::pair<int, double>>& features) {
return predictor->PredictRawOneLine(features); return predictor->PredictRawOneLine(features);
}; };
} else {
predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) {
return predictor->PredictMulticlassOneLine(features);
};
}
} }
// sync up random seed for data partition // sync up random seed for data partition
if (config_.is_parallel_find_bin) { if (config_.is_parallel_find_bin) {
...@@ -163,7 +170,7 @@ void Application::LoadData() { ...@@ -163,7 +170,7 @@ void Application::LoadData() {
train_metric_.push_back(metric); train_metric_.push_back(metric);
} }
} }
// Add validation data, if exists // Add validation data, if it exists
for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) { for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
// add // add
valid_datas_.push_back( valid_datas_.push_back(
...@@ -194,7 +201,7 @@ void Application::LoadData() { ...@@ -194,7 +201,7 @@ void Application::LoadData() {
} }
auto end_time = std::chrono::high_resolution_clock::now(); auto end_time = std::chrono::high_resolution_clock::now();
// output used time on each iteration // output used time on each iteration
Log::Info("Finish loading data, use %f seconds", Log::Info("Finished loading data in %f seconds",
std::chrono::duration<double, std::milli>(end_time - start_time) * 1e-3); std::chrono::duration<double, std::milli>(end_time - start_time) * 1e-3);
} }
...@@ -202,7 +209,7 @@ void Application::InitTrain() { ...@@ -202,7 +209,7 @@ void Application::InitTrain() {
if (config_.is_parallel) { if (config_.is_parallel) {
// need init network // need init network
Network::Init(config_.network_config); Network::Init(config_.network_config);
Log::Info("Finish network initialization"); Log::Info("Finished initializing network");
// sync global random seed for feature patition // sync global random seed for feature patition
if (config_.boosting_type == BoostingType::kGBDT) { if (config_.boosting_type == BoostingType::kGBDT) {
GBDTConfig* gbdt_config = GBDTConfig* gbdt_config =
...@@ -233,11 +240,11 @@ void Application::InitTrain() { ...@@ -233,11 +240,11 @@ void Application::InitTrain() {
boosting_->AddDataset(valid_datas_[i], boosting_->AddDataset(valid_datas_[i],
ConstPtrInVectorWarpper<Metric>(valid_metrics_[i])); ConstPtrInVectorWarpper<Metric>(valid_metrics_[i]));
} }
Log::Info("Finish training initilization."); Log::Info("Finished initializing training");
} }
void Application::Train() { void Application::Train() {
Log::Info("Start train ..."); Log::Info("Started training...");
int total_iter = config_.boosting_config->num_iterations; int total_iter = config_.boosting_config->num_iterations;
bool is_finished = false; bool is_finished = false;
bool need_eval = true; bool need_eval = true;
...@@ -246,37 +253,38 @@ void Application::Train() { ...@@ -246,37 +253,38 @@ void Application::Train() {
is_finished = boosting_->TrainOneIter(nullptr, nullptr, need_eval); is_finished = boosting_->TrainOneIter(nullptr, nullptr, need_eval);
auto end_time = std::chrono::high_resolution_clock::now(); auto end_time = std::chrono::high_resolution_clock::now();
// output used time per iteration // output used time per iteration
Log::Info("%f seconds elapsed, finished %d iteration", std::chrono::duration<double, Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1); std::milli>(end_time - start_time) * 1e-3, iter + 1);
boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str()); boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str());
} }
is_finished = true; is_finished = true;
// save model to file // save model to file
boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str()); boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str());
Log::Info("Finished train"); Log::Info("Finished training");
} }
void Application::Predict() { void Application::Predict() {
boosting_->SetNumUsedModel(config_.io_config.num_model_predict);
// create predictor // create predictor
Predictor predictor(boosting_, config_.io_config.is_sigmoid, Predictor predictor(boosting_, config_.io_config.is_sigmoid,
config_.predict_leaf_index, config_.io_config.num_model_predict); config_.predict_leaf_index);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finish predict."); Log::Info("Finished prediction");
} }
void Application::InitPredict() { void Application::InitPredict() {
boosting_ = boosting_ =
Boosting::CreateBoosting(config_.io_config.input_model.c_str()); Boosting::CreateBoosting(config_.io_config.input_model.c_str());
Log::Info("Finish predict initilization."); Log::Info("Finished initializing prediction");
} }
template<typename T> template<typename T>
T Application::GlobalSyncUpByMin(T& local) { T Application::GlobalSyncUpByMin(T& local) {
T global = local; T global = local;
if (!config_.is_parallel) { if (!config_.is_parallel) {
// not need to sync if not parallel learning // no need to sync if not parallel learning
return global; return global;
} }
Network::Allreduce(reinterpret_cast<char*>(&local), Network::Allreduce(reinterpret_cast<char*>(&local),
......
...@@ -25,12 +25,11 @@ public: ...@@ -25,12 +25,11 @@ public:
/*! /*!
* \brief Constructor * \brief Constructor
* \param boosting Input boosting model * \param boosting Input boosting model
* \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification) * \param is_sigmoid True if need to predict result with sigmoid transform (if needed, like binary classification)
* \param predict_leaf_index True if output leaf index instead of prediction score * \param predict_leaf_index True if output leaf index instead of prediction score
*/ */
Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index, int num_used_model) Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index)
: is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index), : is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index) {
num_used_model_(num_used_model) {
boosting_ = boosting; boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1; num_features_ = boosting_->MaxFeatureIdx() + 1;
num_class_ = boosting_->NumberOfClass(); num_class_ = boosting_->NumberOfClass();
...@@ -57,36 +56,36 @@ public: ...@@ -57,36 +56,36 @@ public:
} }
/*! /*!
* \brief prediction for one record, only raw result(without sigmoid transformation) * \brief prediction for one record, only raw result (without sigmoid transformation)
* \param features Feature for this record * \param features Feature for this record
* \return Prediction result * \return Prediction result
*/ */
double PredictRawOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictRawOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result without sigmoid transformation // get result without sigmoid transformation
return boosting_->PredictRaw(features_[tid], num_used_model_); return std::vector<double>(1, boosting_->PredictRaw(features_[tid]));
} }
/*! /*!
* \brief prediction for one record, only raw result(without sigmoid transformation) * \brief prediction for one record, only raw result (without sigmoid transformation)
* \param features Feature for this record * \param features Feature for this record
* \return Predictied leaf index * \return Predictied leaf index
*/ */
std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result for leaf index // get result for leaf index
return boosting_->PredictLeafIndex(features_[tid], num_used_model_); return boosting_->PredictLeafIndex(features_[tid]);
} }
/*! /*!
* \brief prediction for one record, will use sigmoid transformation if needed(only enabled for binary classification noe) * \brief prediction for one record, will use sigmoid transformation if needed (only enabled for binary classification noe)
* \param features Feature of this record * \param features Feature of this record
* \return Prediction result * \return Prediction result
*/ */
double PredictOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed // get result with sigmoid transform if needed
return boosting_->Predict(features_[tid], num_used_model_); return std::vector<double>(1, boosting_->Predict(features_[tid]));
} }
/*! /*!
...@@ -97,7 +96,7 @@ public: ...@@ -97,7 +96,7 @@ public:
std::vector<double> PredictMulticlassOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictMulticlassOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed // get result with sigmoid transform if needed
return boosting_->PredictMulticlass(features_[tid], num_used_model_); return boosting_->PredictMulticlass(features_[tid]);
} }
/*! /*!
...@@ -116,12 +115,12 @@ public: ...@@ -116,12 +115,12 @@ public:
#endif #endif
if (result_file == NULL) { if (result_file == NULL) {
Log::Fatal("Predition result file %s doesn't exists", data_filename); Log::Fatal("Prediction results file %s doesn't exist", data_filename);
} }
Parser* parser = Parser::CreateParser(data_filename, has_header, num_features_, boosting_->LabelIdx()); Parser* parser = Parser::CreateParser(data_filename, has_header, num_features_, boosting_->LabelIdx());
if (parser == nullptr) { if (parser == nullptr) {
Log::Fatal("Recongnizing input data format failed, filename %s", data_filename); Log::Fatal("Could not recognize the data format of data file %s", data_filename);
} }
// function for parse data // function for parse data
...@@ -136,6 +135,7 @@ public: ...@@ -136,6 +135,7 @@ public:
if (num_class_ > 1) { if (num_class_ > 1) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, double>>& features){
std::vector<double> prediction = PredictMulticlassOneLine(features); std::vector<double> prediction = PredictMulticlassOneLine(features);
Common::Softmax(&prediction);
std::stringstream result_stream_buf; std::stringstream result_stream_buf;
for (size_t i = 0; i < prediction.size(); ++i){ for (size_t i = 0; i < prediction.size(); ++i){
if (i > 0) { if (i > 0) {
...@@ -162,12 +162,12 @@ public: ...@@ -162,12 +162,12 @@ public:
else { else {
if (is_simgoid_) { if (is_simgoid_) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, double>>& features){
return std::to_string(PredictOneLine(features)); return std::to_string(PredictOneLine(features)[0]);
}; };
} }
else { else {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, double>>& features){
return std::to_string(PredictRawOneLine(features)); return std::to_string(PredictRawOneLine(features)[0]);
}; };
} }
} }
...@@ -223,8 +223,6 @@ private: ...@@ -223,8 +223,6 @@ private:
int num_threads_; int num_threads_;
/*! \brief True if output leaf index instead of prediction score */ /*! \brief True if output leaf index instead of prediction score */
bool is_predict_leaf_index_; bool is_predict_leaf_index_;
/*! \brief Number of used model */
int num_used_model_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -40,7 +40,7 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) { ...@@ -40,7 +40,7 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
} }
LoadFileToBoosting(ret, filename); LoadFileToBoosting(ret, filename);
} else { } else {
Log::Fatal("Boosting type in parameter is not same with the type in model file"); Log::Fatal("Boosting type in parameter is not the same as the type in the model file");
} }
return ret; return ret;
} }
......
...@@ -19,7 +19,8 @@ namespace LightGBM { ...@@ -19,7 +19,8 @@ namespace LightGBM {
GBDT::GBDT() GBDT::GBDT()
: train_score_updater_(nullptr), : train_score_updater_(nullptr),
gradients_(nullptr), hessians_(nullptr), gradients_(nullptr), hessians_(nullptr),
out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) { out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr),
saved_model_size_(-1), num_used_model_(0) {
} }
GBDT::~GBDT() { GBDT::~GBDT() {
...@@ -43,6 +44,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -43,6 +44,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
const std::vector<const Metric*>& training_metrics) { const std::vector<const Metric*>& training_metrics) {
gbdt_config_ = dynamic_cast<const GBDTConfig*>(config); gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
iter_ = 0; iter_ = 0;
saved_model_size_ = -1;
max_feature_idx_ = 0; max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round; early_stopping_round_ = gbdt_config_->early_stopping_round;
train_data_ = train_data; train_data_ = train_data;
...@@ -150,7 +152,7 @@ void GBDT::Bagging(int iter, const int curr_class) { ...@@ -150,7 +152,7 @@ void GBDT::Bagging(int iter, const int curr_class) {
bag_data_cnt_ = cur_left_cnt; bag_data_cnt_ = cur_left_cnt;
out_of_bag_data_cnt_ = num_data_ - bag_data_cnt_; out_of_bag_data_cnt_ = num_data_ - bag_data_cnt_;
} }
Log::Info("re-bagging, using %d data to train", bag_data_cnt_); Log::Info("Re-bagging, using %d data to train", bag_data_cnt_);
// set bagging data to tree learner // set bagging data to tree learner
tree_learner_[curr_class]->SetBaggingData(bag_data_indices_, bag_data_cnt_); tree_learner_[curr_class]->SetBaggingData(bag_data_indices_, bag_data_cnt_);
} }
...@@ -180,7 +182,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -180,7 +182,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
Tree * new_tree = tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian+ curr_class * num_data_); Tree * new_tree = tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian+ curr_class * num_data_);
// if cannot learn a new tree, then stop // if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) { if (new_tree->num_leaves() <= 1) {
Log::Info("Can't training anymore, there isn't any leaf meets split requirements."); Log::Info("Stopped training because there are no more leafs that meet the split requirements.");
return true; return true;
} }
...@@ -229,7 +231,7 @@ bool GBDT::OutputMetric(int iter) { ...@@ -229,7 +231,7 @@ bool GBDT::OutputMetric(int iter) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName(); auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score()); auto scores = sub_metric->Eval(train_score_updater_->score());
Log::Info("Iteration:%d, %s : %s", iter, name, Common::ArrayToString<double>(scores, ' ').c_str()); Log::Info("Iteration: %d, %s: %s", iter, name, Common::ArrayToString<double>(scores, ' ').c_str());
} }
} }
// print validation metric // print validation metric
...@@ -239,7 +241,7 @@ bool GBDT::OutputMetric(int iter) { ...@@ -239,7 +241,7 @@ bool GBDT::OutputMetric(int iter) {
auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score()); auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
if ((iter % gbdt_config_->output_freq) == 0) { if ((iter % gbdt_config_->output_freq) == 0) {
auto name = valid_metrics_[i][j]->GetName(); auto name = valid_metrics_[i][j]->GetName();
Log::Info("Iteration:%d, %s : %s", iter, name, Common::ArrayToString<double>(test_scores, ' ').c_str()); Log::Info("Iteration: %d, %s: %s", iter, name, Common::ArrayToString<double>(test_scores, ' ').c_str());
} }
if (!ret && early_stopping_round_ > 0) { if (!ret && early_stopping_round_ > 0) {
bool the_bigger_the_better = valid_metrics_[i][j]->is_bigger_better(); bool the_bigger_the_better = valid_metrics_[i][j]->is_bigger_better();
...@@ -353,7 +355,7 @@ void GBDT::ModelsFromString(const std::string& model_str) { ...@@ -353,7 +355,7 @@ void GBDT::ModelsFromString(const std::string& model_str) {
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n'); std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
size_t i = 0; size_t i = 0;
// get number of class // get number of classes
while (i < lines.size()) { while (i < lines.size()) {
size_t find_pos = lines[i].find("num_class="); size_t find_pos = lines[i].find("num_class=");
if (find_pos != std::string::npos) { if (find_pos != std::string::npos) {
...@@ -366,7 +368,7 @@ void GBDT::ModelsFromString(const std::string& model_str) { ...@@ -366,7 +368,7 @@ void GBDT::ModelsFromString(const std::string& model_str) {
} }
} }
if (i == lines.size()) { if (i == lines.size()) {
Log::Fatal("Model file doesn't contain number of class"); Log::Fatal("Model file doesn't specify the number of classes");
return; return;
} }
...@@ -384,7 +386,7 @@ void GBDT::ModelsFromString(const std::string& model_str) { ...@@ -384,7 +386,7 @@ void GBDT::ModelsFromString(const std::string& model_str) {
} }
} }
if (i == lines.size()) { if (i == lines.size()) {
Log::Fatal("Model file doesn't contain label index"); Log::Fatal("Model file doesn't specify the label index");
return; return;
} }
...@@ -402,7 +404,7 @@ void GBDT::ModelsFromString(const std::string& model_str) { ...@@ -402,7 +404,7 @@ void GBDT::ModelsFromString(const std::string& model_str) {
} }
} }
if (i == lines.size()) { if (i == lines.size()) {
Log::Fatal("Model file doesn't contain max_feature_idx"); Log::Fatal("Model file doesn't specify max_feature_idx");
return; return;
} }
// get sigmoid parameter // get sigmoid parameter
...@@ -437,7 +439,8 @@ void GBDT::ModelsFromString(const std::string& model_str) { ...@@ -437,7 +439,8 @@ void GBDT::ModelsFromString(const std::string& model_str) {
++i; ++i;
} }
} }
Log::Info("%d models has been loaded\n", models_.size()); Log::Info("Finished loading %d models", models_.size());
num_used_model_ = static_cast<int>(models_.size()) / num_class_;
} }
std::string GBDT::FeatureImportance() const { std::string GBDT::FeatureImportance() const {
...@@ -467,23 +470,17 @@ std::string GBDT::FeatureImportance() const { ...@@ -467,23 +470,17 @@ std::string GBDT::FeatureImportance() const {
return str_buf.str(); return str_buf.str();
} }
double GBDT::PredictRaw(const double* value, int num_used_model) const { double GBDT::PredictRaw(const double* value) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
double ret = 0.0f; double ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model_; ++i) {
ret += models_[i]->Predict(value); ret += models_[i]->Predict(value);
} }
return ret; return ret;
} }
double GBDT::Predict(const double* value, int num_used_model) const { double GBDT::Predict(const double* value) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
double ret = 0.0f; double ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model_; ++i) {
ret += models_[i]->Predict(value); ret += models_[i]->Predict(value);
} }
// if need sigmoid transform // if need sigmoid transform
...@@ -493,26 +490,19 @@ double GBDT::Predict(const double* value, int num_used_model) const { ...@@ -493,26 +490,19 @@ double GBDT::Predict(const double* value, int num_used_model) const {
return ret; return ret;
} }
std::vector<double> GBDT::PredictMulticlass(const double* value, int num_used_model) const { std::vector<double> GBDT::PredictMulticlass(const double* value) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size()) / num_class_;
}
std::vector<double> ret(num_class_, 0.0f); std::vector<double> ret(num_class_, 0.0f);
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model_; ++i) {
for (int j = 0; j < num_class_; ++j){ for (int j = 0; j < num_class_; ++j){
ret[j] += models_[i * num_class_ + j] -> Predict(value); ret[j] += models_[i * num_class_ + j] -> Predict(value);
} }
} }
Common::Softmax(&ret);
return ret; return ret;
} }
std::vector<int> GBDT::PredictLeafIndex(const double* value, int num_used_model) const { std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
std::vector<int> ret; std::vector<int> ret;
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model_; ++i) {
ret.push_back(models_[i]->PredictLeafIndex(value)); ret.push_back(models_[i]->PredictLeafIndex(value));
} }
return ret; return ret;
......
...@@ -55,33 +55,30 @@ public: ...@@ -55,33 +55,30 @@ public:
/*! /*!
* \brief Predtion for one record without sigmoid transformation * \brief Predtion for one record without sigmoid transformation
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double PredictRaw(const double* feature_values, int num_used_model) const override; double PredictRaw(const double* feature_values) const override;
/*! /*!
* \brief Predtion for one record with sigmoid transformation if enabled * \brief Predtion for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double Predict(const double* feature_values, int num_used_model) const override; double Predict(const double* feature_values) const override;
/*! /*!
* \brief Predtion for multiclass classification * \brief Predtion for multiclass classification
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result, num_class numbers per line * \return Prediction result, num_class numbers per line
*/ */
std::vector<double> PredictMulticlass(const double* value, int num_used_model) const override; std::vector<double> PredictMulticlass(const double* value) const override;
/*! /*!
* \brief Predtion for one record with leaf index * \brief Predtion for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
std::vector<int> PredictLeafIndex(const double* value, int num_used_model) const override; std::vector<int> PredictLeafIndex(const double* value) const override;
/*! /*!
* \brief Serialize models by string * \brief Serialize models by string
...@@ -116,6 +113,16 @@ public: ...@@ -116,6 +113,16 @@ public:
*/ */
inline int NumberOfClass() const override { return num_class_; } inline int NumberOfClass() const override { return num_class_; }
/*!
* \brief Set number of used model for prediction
*/
inline void SetNumUsedModel(int num_used_model) {
if (num_used_model >= 0) {
num_used_model_ = static_cast<int>(num_used_model / num_class_);
}
}
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
*/ */
...@@ -208,9 +215,11 @@ private: ...@@ -208,9 +215,11 @@ private:
/*! \brief Index of label column */ /*! \brief Index of label column */
data_size_t label_idx_; data_size_t label_idx_;
/*! \brief Saved number of models */ /*! \brief Saved number of models */
int saved_model_size_ = -1; int saved_model_size_;
/*! \brief File to write models */ /*! \brief File to write models */
std::ofstream model_output_file_; std::ofstream model_output_file_;
/*! \brief number of used model */
int num_used_model_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -27,7 +27,7 @@ public: ...@@ -27,7 +27,7 @@ public:
const float* init_score = data->metadata().init_score(); const float* init_score = data->metadata().init_score();
// if exists initial score, will start from it // if exists initial score, will start from it
if (init_score != nullptr) { if (init_score != nullptr) {
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_ * num_class; ++i) {
score_[i] = init_score[i]; score_[i] = init_score[i];
} }
} }
......
...@@ -77,7 +77,7 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s ...@@ -77,7 +77,7 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
if (value == std::string("gbdt") || value == std::string("gbrt")) { if (value == std::string("gbdt") || value == std::string("gbrt")) {
boosting_type = BoostingType::kGBDT; boosting_type = BoostingType::kGBDT;
} else { } else {
Log::Fatal("Boosting type %s error", value.c_str()); Log::Fatal("Unknown boosting type %s", value.c_str());
} }
} }
} }
...@@ -125,7 +125,7 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin ...@@ -125,7 +125,7 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
|| value == std::string("test")) { || value == std::string("test")) {
task_type = TaskType::kPredict; task_type = TaskType::kPredict;
} else { } else {
Log::Fatal("Task type error"); Log::Fatal("Unknown task type %s", value.c_str());
} }
} }
} }
...@@ -138,19 +138,19 @@ void OverallConfig::CheckParamConflict() { ...@@ -138,19 +138,19 @@ void OverallConfig::CheckParamConflict() {
int num_class_check = gbdt_config->num_class; int num_class_check = gbdt_config->num_class;
if (objective_type_multiclass){ if (objective_type_multiclass){
if (num_class_check <= 1){ if (num_class_check <= 1){
Log::Fatal("You should specify number of class(>=2) for multiclass training."); Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
} }
} }
else { else {
if (task_type == TaskType::kTrain && num_class_check != 1){ if (task_type == TaskType::kTrain && num_class_check != 1){
Log::Fatal("Number of class must be 1 for non-multiclass training."); Log::Fatal("Number of classes must be 1 for non-multiclass training");
} }
} }
for (std::string metric_type : metric_types){ for (std::string metric_type : metric_types){
bool metric_type_multiclass = ( metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error")); bool metric_type_multiclass = ( metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error"));
if ((objective_type_multiclass && !metric_type_multiclass) if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)){ || (!objective_type_multiclass && metric_type_multiclass)){
Log::Fatal("Objective and metrics don't match."); Log::Fatal("Objective and metrics don't match");
} }
} }
...@@ -172,9 +172,9 @@ void OverallConfig::CheckParamConflict() { ...@@ -172,9 +172,9 @@ void OverallConfig::CheckParamConflict() {
} else if (gbdt_config->tree_learner_type == TreeLearnerType::kDataParallelTreeLearner) { } else if (gbdt_config->tree_learner_type == TreeLearnerType::kDataParallelTreeLearner) {
is_parallel_find_bin = true; is_parallel_find_bin = true;
if (gbdt_config->tree_config.histogram_pool_size >= 0) { if (gbdt_config->tree_config.histogram_pool_size >= 0) {
Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this for reducing communication cost." Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs"
, gbdt_config->tree_config.histogram_pool_size); , gbdt_config->tree_config.histogram_pool_size);
// Change pool size to -1(not limit) when using data parallel for reducing communication cost // Change pool size to -1 (not limit) when using data parallel to reduce communication costs
gbdt_config->tree_config.histogram_pool_size = -1; gbdt_config->tree_config.histogram_pool_size = -1;
} }
...@@ -184,6 +184,7 @@ void OverallConfig::CheckParamConflict() { ...@@ -184,6 +184,7 @@ void OverallConfig::CheckParamConflict() {
void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "max_bin", &max_bin); GetInt(params, "max_bin", &max_bin);
CHECK(max_bin > 0); CHECK(max_bin > 0);
GetInt(params, "num_class", &num_class);
GetInt(params, "data_random_seed", &data_random_seed); GetInt(params, "data_random_seed", &data_random_seed);
if (!GetString(params, "data", &data_filename)) { if (!GetString(params, "data", &data_filename)) {
...@@ -236,7 +237,6 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa ...@@ -236,7 +237,6 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) { void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetDouble(params, "sigmoid", &sigmoid); GetDouble(params, "sigmoid", &sigmoid);
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToDoubleArray(tmp_str, ','); label_gain = Common::StringToDoubleArray(tmp_str, ',');
...@@ -294,7 +294,6 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -294,7 +294,6 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
CHECK(output_freq >= 0); CHECK(output_freq >= 0);
GetBool(params, "is_training_metric", &is_provide_training_metric); GetBool(params, "is_training_metric", &is_provide_training_metric);
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
} }
void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) { void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
...@@ -309,7 +308,7 @@ void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::s ...@@ -309,7 +308,7 @@ void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::s
tree_learner_type = TreeLearnerType::kDataParallelTreeLearner; tree_learner_type = TreeLearnerType::kDataParallelTreeLearner;
} }
else { else {
Log::Fatal("Tree learner type error"); Log::Fatal("Unknown tree learner type %s", value.c_str());
} }
} }
} }
......
...@@ -20,15 +20,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -20,15 +20,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
:data_filename_(data_filename), random_(io_config.data_random_seed), :data_filename_(data_filename), random_(io_config.data_random_seed),
max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse), predict_fun_(predict_fun) { max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse), predict_fun_(predict_fun) {
num_class_ = io_config.num_class;
CheckCanLoadFromBin(); CheckCanLoadFromBin();
if (is_loading_from_binfile_ && predict_fun != nullptr) { if (is_loading_from_binfile_ && predict_fun != nullptr) {
Log::Info("Cannot performing initialization of prediction by using binary file, using text file instead"); Log::Info("Cannot initialize prediction by using a binary file, using text file instead");
is_loading_from_binfile_ = false; is_loading_from_binfile_ = false;
} }
if (!is_loading_from_binfile_) { if (!is_loading_from_binfile_) {
// load weight, query information and initilize score // load weight, query information and initialize score
metadata_.Init(data_filename, init_score_filename); metadata_.Init(data_filename, init_score_filename, num_class_);
// create text reader // create text reader
text_reader_ = new TextReader<data_size_t>(data_filename, io_config.has_header); text_reader_ = new TextReader<data_size_t>(data_filename, io_config.has_header);
...@@ -49,17 +51,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -49,17 +51,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
std::string name = io_config.label_column.substr(name_prefix.size()); std::string name = io_config.label_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
label_idx_ = name2idx[name]; label_idx_ = name2idx[name];
Log::Info("use %s column as label", name.c_str()); Log::Info("Using column %s as label", name.c_str());
} else { } else {
Log::Fatal("cannot find label column: %s in data file", name.c_str()); Log::Fatal("Could not find label column %s in data file", name.c_str());
} }
} else { } else {
if (!Common::AtoiAndCheck(io_config.label_column.c_str(), &label_idx_)) { if (!Common::AtoiAndCheck(io_config.label_column.c_str(), &label_idx_)) {
Log::Fatal("label_column is not a number, \ Log::Fatal("label_column is not a number, \
if you want to use column name, \ if you want to use a column name, \
please add prefix \"name:\" before column name"); please add the prefix \"name:\" to the column name");
} }
Log::Info("use %d-th column as label", label_idx_); Log::Info("Using column number %d as label", label_idx_);
} }
} }
if (feature_names_.size() > 0) { if (feature_names_.size() > 0) {
...@@ -77,7 +79,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -77,7 +79,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
if (tmp > label_idx_) { tmp -= 1; } if (tmp > label_idx_) { tmp -= 1; }
ignore_features_.emplace(tmp); ignore_features_.emplace(tmp);
} else { } else {
Log::Fatal("cannot find column: %s in data file", name.c_str()); Log::Fatal("Could not find ignore column %s in data file", name.c_str());
} }
} }
} else { } else {
...@@ -85,8 +87,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -85,8 +87,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
int tmp = 0; int tmp = 0;
if (!Common::AtoiAndCheck(token.c_str(), &tmp)) { if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Log::Fatal("ignore_column is not a number, \ Log::Fatal("ignore_column is not a number, \
if you want to use column name, \ if you want to use a column name, \
please add prefix \"name:\" before column name"); please add the prefix \"name:\" to the column name");
} }
// skip for label column // skip for label column
if (tmp > label_idx_) { tmp -= 1; } if (tmp > label_idx_) { tmp -= 1; }
...@@ -102,17 +104,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -102,17 +104,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
std::string name = io_config.weight_column.substr(name_prefix.size()); std::string name = io_config.weight_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
weight_idx_ = name2idx[name]; weight_idx_ = name2idx[name];
Log::Info("use %s column as weight", name.c_str()); Log::Info("Using column %s as weight", name.c_str());
} else { } else {
Log::Fatal("cannot find weight column: %s in data file", name.c_str()); Log::Fatal("Could not find weight column %s in data file", name.c_str());
} }
} else { } else {
if (!Common::AtoiAndCheck(io_config.weight_column.c_str(), &weight_idx_)) { if (!Common::AtoiAndCheck(io_config.weight_column.c_str(), &weight_idx_)) {
Log::Fatal("weight_column is not a number, \ Log::Fatal("weight_column is not a number, \
if you want to use column name, \ if you want to use a column name, \
please add prefix \"name:\" before column name"); please add the prefix \"name:\" to the column name");
} }
Log::Info("use %d-th column as weight", weight_idx_); Log::Info("Using column number %d as weight", weight_idx_);
} }
// skip for label column // skip for label column
if (weight_idx_ > label_idx_) { if (weight_idx_ > label_idx_) {
...@@ -126,17 +128,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -126,17 +128,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
std::string name = io_config.group_column.substr(name_prefix.size()); std::string name = io_config.group_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
group_idx_ = name2idx[name]; group_idx_ = name2idx[name];
Log::Info("use %s column as group/query id", name.c_str()); Log::Info("Using column %s as group/query id", name.c_str());
} else { } else {
Log::Fatal("cannot find group/query column: %s in data file", name.c_str()); Log::Fatal("Could not find group/query column %s in data file", name.c_str());
} }
} else { } else {
if (!Common::AtoiAndCheck(io_config.group_column.c_str(), &group_idx_)) { if (!Common::AtoiAndCheck(io_config.group_column.c_str(), &group_idx_)) {
Log::Fatal("group_column is not a number, \ Log::Fatal("group_column is not a number, \
if you want to use column name, \ if you want to use a column name, \
please add prefix \"name:\" before column name"); please add the prefix \"name:\" to the column name");
} }
Log::Info("use %d-th column as group/query id", group_idx_); Log::Info("Using column number %d as group/query id", group_idx_);
} }
// skip for label column // skip for label column
if (group_idx_ > label_idx_) { if (group_idx_ > label_idx_) {
...@@ -148,11 +150,11 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -148,11 +150,11 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
// create text parser // create text parser
parser_ = Parser::CreateParser(data_filename_, io_config.has_header, 0, label_idx_); parser_ = Parser::CreateParser(data_filename_, io_config.has_header, 0, label_idx_);
if (parser_ == nullptr) { if (parser_ == nullptr) {
Log::Fatal("Cannot recognising input data format, filename: %s", data_filename_); Log::Fatal("Could not recognize data format of %s", data_filename_);
} }
} else { } else {
// only need to load initilize score, other meta data will be loaded from bin flie // only need to load initialize score, other meta data will be loaded from binary file
metadata_.Init(init_score_filename); metadata_.Init(init_score_filename, num_class_);
Log::Info("Loading data set from binary file"); Log::Info("Loading data set from binary file");
parser_ = nullptr; parser_ = nullptr;
text_reader_ = nullptr; text_reader_ = nullptr;
...@@ -255,8 +257,8 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti ...@@ -255,8 +257,8 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti
[this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries] [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
(data_size_t line_idx) { (data_size_t line_idx) {
if (qid >= num_queries) { if (qid >= num_queries) {
Log::Fatal("Query id is exceed the range of query file, \ Log::Fatal("Query id exceeds the range of the query file, \
please ensure your query file is correct"); please ensure the query file is correct");
} }
if (line_idx >= query_boundaries[qid + 1]) { if (line_idx >= query_boundaries[qid + 1]) {
// if is new query // if is new query
...@@ -324,7 +326,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -324,7 +326,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// start find bins // start find bins
if (num_machines == 1) { if (num_machines == 1) {
std::vector<BinMapper*> bin_mappers(sample_values.size()); std::vector<BinMapper*> bin_mappers(sample_values.size());
// if only 1 machines, find bin locally // if only one machine, find bin locally
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) { for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
if (ignore_features_.count(i) > 0) { if (ignore_features_.count(i) > 0) {
...@@ -337,7 +339,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -337,7 +339,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
for (size_t i = 0; i < sample_values.size(); ++i) { for (size_t i = 0; i < sample_values.size(); ++i) {
if (bin_mappers[i] == nullptr) { if (bin_mappers[i] == nullptr) {
Log::Warning("Ignore Feature %s ", feature_names_[i].c_str()); Log::Warning("Ignoring feature %s", feature_names_[i].c_str());
} }
else if (!bin_mappers[i]->is_trival()) { else if (!bin_mappers[i]->is_trival()) {
// map real feature index to used feature index // map real feature index to used feature index
...@@ -347,7 +349,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -347,7 +349,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
num_data_, is_enable_sparse_)); num_data_, is_enable_sparse_));
} else { } else {
// if feature is trival(only 1 bin), free spaces // if feature is trival(only 1 bin), free spaces
Log::Warning("Feature %s only contains one value, will be ignored", feature_names_[i].c_str()); Log::Warning("Ignoring feature %s, only has one value", feature_names_[i].c_str());
delete bin_mappers[i]; delete bin_mappers[i];
} }
} }
...@@ -395,7 +397,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -395,7 +397,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// restore features bins from buffer // restore features bins from buffer
for (int i = 0; i < total_num_feature; ++i) { for (int i = 0; i < total_num_feature; ++i) {
if (ignore_features_.count(i) > 0) { if (ignore_features_.count(i) > 0) {
Log::Warning("Ignore Feature %s ", feature_names_[i].c_str()); Log::Warning("Ignoring feature %s", feature_names_[i].c_str());
continue; continue;
} }
BinMapper* bin_mapper = new BinMapper(); BinMapper* bin_mapper = new BinMapper();
...@@ -404,7 +406,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -404,7 +406,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
used_feature_map_[i] = static_cast<int>(features_.size()); used_feature_map_[i] = static_cast<int>(features_.size());
features_.push_back(new Feature(static_cast<int>(i), bin_mapper, num_data_, is_enable_sparse_)); features_.push_back(new Feature(static_cast<int>(i), bin_mapper, num_data_, is_enable_sparse_));
} else { } else {
Log::Warning("Feature %s only contains one value, will be ignored", feature_names_[i].c_str()); Log::Warning("Ignoring feature %s, only has one value", feature_names_[i].c_str());
delete bin_mapper; delete bin_mapper;
} }
} }
...@@ -422,8 +424,8 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b ...@@ -422,8 +424,8 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b
// don't support query id in data file when training in parallel // don't support query id in data file when training in parallel
if (num_machines > 1 && !is_pre_partition) { if (num_machines > 1 && !is_pre_partition) {
if (group_idx_ > 0) { if (group_idx_ > 0) {
Log::Fatal("Don't support query id in data file when training parallel without pre-partition. \ Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training. \
Please use an additional query file or pre-partition your data"); Please use an additional query file or pre-partition the data");
} }
} }
used_data_indices_.clear(); used_data_indices_.clear();
...@@ -437,7 +439,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b ...@@ -437,7 +439,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b
// construct feature bin mappers // construct feature bin mappers
ConstructBinMappers(rank, num_machines, sample_data); ConstructBinMappers(rank, num_machines, sample_data);
// initialize label // initialize label
metadata_.Init(num_data_, weight_idx_, group_idx_); metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_);
// extract features // extract features
ExtractFeaturesFromMemory(); ExtractFeaturesFromMemory();
} else { } else {
...@@ -447,7 +449,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b ...@@ -447,7 +449,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b
// construct feature bin mappers // construct feature bin mappers
ConstructBinMappers(rank, num_machines, sample_data); ConstructBinMappers(rank, num_machines, sample_data);
// initialize label // initialize label
metadata_.Init(num_data_, weight_idx_, group_idx_); metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_);
// extract features // extract features
ExtractFeaturesFromFile(); ExtractFeaturesFromFile();
...@@ -472,7 +474,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo ...@@ -472,7 +474,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
// read data in memory // read data in memory
LoadDataToMemory(0, 1, false); LoadDataToMemory(0, 1, false);
// initialize label // initialize label
metadata_.Init(num_data_, weight_idx_, group_idx_); metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_);
features_.clear(); features_.clear();
// copy feature bin mapper data // copy feature bin mapper data
for (Feature* feature : train_set->features_) { for (Feature* feature : train_set->features_) {
...@@ -488,7 +490,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo ...@@ -488,7 +490,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
// Get number of lines of data file // Get number of lines of data file
num_data_ = static_cast<data_size_t>(text_reader_->CountLine()); num_data_ = static_cast<data_size_t>(text_reader_->CountLine());
// initialize label // initialize label
metadata_.Init(num_data_, weight_idx_, group_idx_); metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_);
features_.clear(); features_.clear();
// copy feature bin mapper data // copy feature bin mapper data
for (Feature* feature : train_set->features_) { for (Feature* feature : train_set->features_) {
...@@ -539,14 +541,14 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -539,14 +541,14 @@ void Dataset::ExtractFeaturesFromMemory() {
if (inner_data.first == weight_idx_) { if (inner_data.first == weight_idx_) {
metadata_.SetWeightAt(i, static_cast<float>(inner_data.second)); metadata_.SetWeightAt(i, static_cast<float>(inner_data.second));
} else if (inner_data.first == group_idx_) { } else if (inner_data.first == group_idx_) {
metadata_.SetQueryAt(i, static_cast<float>(inner_data.second)); metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
} }
} }
} }
} }
} else { } else {
// if need to prediction with initial model // if need to prediction with initial model
float* init_score = new float[num_data_]; float* init_score = new float[num_data_ * num_class_];
#pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
...@@ -554,7 +556,10 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -554,7 +556,10 @@ void Dataset::ExtractFeaturesFromMemory() {
// parser // parser
parser_->ParseOneLine(text_reader_->Lines()[i].c_str(), &oneline_features, &tmp_label); parser_->ParseOneLine(text_reader_->Lines()[i].c_str(), &oneline_features, &tmp_label);
// set initial score // set initial score
init_score[i] = static_cast<float>(predict_fun_(oneline_features)); std::vector<double> oneline_init_score = predict_fun_(oneline_features);
for (int k = 0; k < num_class_; ++k){
init_score[k * num_data_ + i] = static_cast<float>(oneline_init_score[k]);
}
// set label // set label
metadata_.SetLabelAt(i, static_cast<float>(tmp_label)); metadata_.SetLabelAt(i, static_cast<float>(tmp_label));
// free processed line: // free processed line:
...@@ -572,13 +577,13 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -572,13 +577,13 @@ void Dataset::ExtractFeaturesFromMemory() {
if (inner_data.first == weight_idx_) { if (inner_data.first == weight_idx_) {
metadata_.SetWeightAt(i, static_cast<float>(inner_data.second)); metadata_.SetWeightAt(i, static_cast<float>(inner_data.second));
} else if (inner_data.first == group_idx_) { } else if (inner_data.first == group_idx_) {
metadata_.SetQueryAt(i, static_cast<float>(inner_data.second)); metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
} }
} }
} }
} }
// metadata_ will manage space of init_score // metadata_ will manage space of init_score
metadata_.SetInitScore(init_score, num_data_); metadata_.SetInitScore(init_score, num_data_ * num_class_);
delete[] init_score; delete[] init_score;
} }
...@@ -594,7 +599,7 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -594,7 +599,7 @@ void Dataset::ExtractFeaturesFromMemory() {
void Dataset::ExtractFeaturesFromFile() { void Dataset::ExtractFeaturesFromFile() {
float* init_score = nullptr; float* init_score = nullptr;
if (predict_fun_ != nullptr) { if (predict_fun_ != nullptr) {
init_score = new float[num_data_]; init_score = new float[num_data_ * num_class_];
} }
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
[this, &init_score] [this, &init_score]
...@@ -609,7 +614,10 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -609,7 +614,10 @@ void Dataset::ExtractFeaturesFromFile() {
parser_->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label); parser_->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
// set initial score // set initial score
if (init_score != nullptr) { if (init_score != nullptr) {
init_score[start_idx + i] = static_cast<float>(predict_fun_(oneline_features)); std::vector<double> oneline_init_score = predict_fun_(oneline_features);
for (int k = 0; k < num_class_; ++k){
init_score[k * num_data_ + start_idx + i] = static_cast<float>(oneline_init_score[k]);
}
} }
// set label // set label
metadata_.SetLabelAt(start_idx + i, static_cast<float>(tmp_label)); metadata_.SetLabelAt(start_idx + i, static_cast<float>(tmp_label));
...@@ -624,7 +632,7 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -624,7 +632,7 @@ void Dataset::ExtractFeaturesFromFile() {
if (inner_data.first == weight_idx_) { if (inner_data.first == weight_idx_) {
metadata_.SetWeightAt(start_idx + i, static_cast<float>(inner_data.second)); metadata_.SetWeightAt(start_idx + i, static_cast<float>(inner_data.second));
} else if (inner_data.first == group_idx_) { } else if (inner_data.first == group_idx_) {
metadata_.SetQueryAt(start_idx + i, static_cast<float>(inner_data.second)); metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
} }
} }
} }
...@@ -641,7 +649,7 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -641,7 +649,7 @@ void Dataset::ExtractFeaturesFromFile() {
// metadata_ will manage space of init_score // metadata_ will manage space of init_score
if (init_score != nullptr) { if (init_score != nullptr) {
metadata_.SetInitScore(init_score, num_data_); metadata_.SetInitScore(init_score, num_data_ * num_class_);
delete[] init_score; delete[] init_score;
} }
...@@ -663,10 +671,10 @@ void Dataset::SaveBinaryFile() { ...@@ -663,10 +671,10 @@ void Dataset::SaveBinaryFile() {
file = fopen(bin_filename.c_str(), "wb"); file = fopen(bin_filename.c_str(), "wb");
#endif #endif
if (file == NULL) { if (file == NULL) {
Log::Fatal("Cannot write binary data to %s ", bin_filename.c_str()); Log::Fatal("Could not write binary data to %s", bin_filename.c_str());
} }
Log::Info("Saving data to binary file: %s", data_filename_); Log::Info("Saving data to binary file %s", data_filename_);
// get size of header // get size of header
size_t size_of_header = sizeof(global_num_data_) + sizeof(is_enable_sparse_) size_t size_of_header = sizeof(global_num_data_) + sizeof(is_enable_sparse_)
...@@ -746,7 +754,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -746,7 +754,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
#endif #endif
if (file == NULL) { if (file == NULL) {
Log::Fatal("Cannot read binary data from %s", bin_filename.c_str()); Log::Fatal("Could not read binary data from %s", bin_filename.c_str());
} }
// buffer to read binary file // buffer to read binary file
...@@ -757,7 +765,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -757,7 +765,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
size_t read_cnt = fread(buffer, sizeof(size_t), 1, file); size_t read_cnt = fread(buffer, sizeof(size_t), 1, file);
if (read_cnt != 1) { if (read_cnt != 1) {
Log::Fatal("Binary file format error at header size"); Log::Fatal("Binary file error: header has the wrong size");
} }
size_t size_of_head = *(reinterpret_cast<size_t*>(buffer)); size_t size_of_head = *(reinterpret_cast<size_t*>(buffer));
...@@ -772,7 +780,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -772,7 +780,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt = fread(buffer, 1, size_of_head, file); read_cnt = fread(buffer, 1, size_of_head, file);
if (read_cnt != size_of_head) { if (read_cnt != size_of_head) {
Log::Fatal("Binary file format error at header"); Log::Fatal("Binary file error: header is incorrect");
} }
// get header // get header
const char* mem_ptr = buffer; const char* mem_ptr = buffer;
...@@ -815,7 +823,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -815,7 +823,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt = fread(buffer, sizeof(size_t), 1, file); read_cnt = fread(buffer, sizeof(size_t), 1, file);
if (read_cnt != 1) { if (read_cnt != 1) {
Log::Fatal("Binary file format error: wrong size of meta data"); Log::Fatal("Binary file error: meta data has the wrong size");
} }
size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer)); size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer));
...@@ -830,7 +838,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -830,7 +838,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt = fread(buffer, 1, size_of_metadata, file); read_cnt = fread(buffer, 1, size_of_metadata, file);
if (read_cnt != size_of_metadata) { if (read_cnt != size_of_metadata) {
Log::Fatal("Binary file format error: wrong size of meta data"); Log::Fatal("Binary file error: meta data is incorrect");
} }
// load meta data // load meta data
metadata_.LoadFromMemory(buffer); metadata_.LoadFromMemory(buffer);
...@@ -854,7 +862,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -854,7 +862,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
bool is_query_used = false; bool is_query_used = false;
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
if (qid >= num_queries) { if (qid >= num_queries) {
Log::Fatal("current query is exceed the range of query file, please ensure your query file is correct"); Log::Fatal("Current query exceeds the range of the query file, please ensure the query file is correct");
} }
if (i >= query_boundaries[qid + 1]) { if (i >= query_boundaries[qid + 1]) {
// if is new query // if is new query
...@@ -877,7 +885,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -877,7 +885,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
// read feature size // read feature size
read_cnt = fread(buffer, sizeof(size_t), 1, file); read_cnt = fread(buffer, sizeof(size_t), 1, file);
if (read_cnt != 1) { if (read_cnt != 1) {
Log::Fatal("Binary file format error at feature %d's size", i); Log::Fatal("Binary file error: feature %d has the wrong size", i);
} }
size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer)); size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer));
// re-allocate space if not enough // re-allocate space if not enough
...@@ -890,7 +898,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -890,7 +898,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt = fread(buffer, 1, size_of_feature, file); read_cnt = fread(buffer, 1, size_of_feature, file);
if (read_cnt != size_of_feature) { if (read_cnt != size_of_feature) {
Log::Fatal("Binary file format error at feature %d loading , read count %d", i, read_cnt); Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
} }
features_.push_back(new Feature(buffer, static_cast<data_size_t>(global_num_data_), used_data_indices_)); features_.push_back(new Feature(buffer, static_cast<data_size_t>(global_num_data_), used_data_indices_));
} }
...@@ -903,7 +911,7 @@ void Dataset::CheckDataset() { ...@@ -903,7 +911,7 @@ void Dataset::CheckDataset() {
Log::Fatal("Data file %s is empty", data_filename_); Log::Fatal("Data file %s is empty", data_filename_);
} }
if (features_.size() <= 0) { if (features_.size() <= 0) {
Log::Fatal("Usable feature of data %s is null", data_filename_); Log::Fatal("No usable features in data file %s", data_filename_);
} }
} }
......
...@@ -14,9 +14,10 @@ Metadata::Metadata() ...@@ -14,9 +14,10 @@ Metadata::Metadata()
} }
void Metadata::Init(const char * data_filename, const char* init_score_filename) { void Metadata::Init(const char * data_filename, const char* init_score_filename, const int num_class) {
data_filename_ = data_filename; data_filename_ = data_filename;
init_score_filename_ = init_score_filename; init_score_filename_ = init_score_filename;
num_class_ = num_class;
// for lambdarank, it needs query data for partition data in parallel learning // for lambdarank, it needs query data for partition data in parallel learning
LoadQueryBoundaries(); LoadQueryBoundaries();
LoadWeights(); LoadWeights();
...@@ -24,8 +25,9 @@ void Metadata::Init(const char * data_filename, const char* init_score_filename) ...@@ -24,8 +25,9 @@ void Metadata::Init(const char * data_filename, const char* init_score_filename)
LoadInitialScore(); LoadInitialScore();
} }
void Metadata::Init(const char* init_score_filename) { void Metadata::Init(const char* init_score_filename, const int num_class) {
init_score_filename_ = init_score_filename; init_score_filename_ = init_score_filename;
num_class_ = num_class;
LoadInitialScore(); LoadInitialScore();
} }
...@@ -40,12 +42,13 @@ Metadata::~Metadata() { ...@@ -40,12 +42,13 @@ Metadata::~Metadata() {
} }
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) { void Metadata::Init(data_size_t num_data, int num_class, int weight_idx, int query_idx) {
num_data_ = num_data; num_data_ = num_data;
num_class_ = num_class;
label_ = new float[num_data_]; label_ = new float[num_data_];
if (weight_idx >= 0) { if (weight_idx >= 0) {
if (weights_ != nullptr) { if (weights_ != nullptr) {
Log::Info("using weight in data file, and ignore additional weight file"); Log::Info("Using weights in data file, ignoring the additional weights file");
delete[] weights_; delete[] weights_;
} }
weights_ = new float[num_data_]; weights_ = new float[num_data_];
...@@ -54,7 +57,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) { ...@@ -54,7 +57,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
} }
if (query_idx >= 0) { if (query_idx >= 0) {
if (query_boundaries_ != nullptr) { if (query_boundaries_ != nullptr) {
Log::Info("using query id in data file, and ignore additional query file"); Log::Info("Using query id in data file, ignoring the additional query file");
delete[] query_boundaries_; delete[] query_boundaries_;
} }
if (query_weights_ != nullptr) { delete[] query_weights_; } if (query_weights_ != nullptr) { delete[] query_weights_; }
...@@ -106,7 +109,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -106,7 +109,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
} }
// check weights // check weights
if (weights_ != nullptr && num_weights_ != num_data_) { if (weights_ != nullptr && num_weights_ != num_data_) {
Log::Fatal("Initial weight size doesn't equal to data"); Log::Fatal("Weights size doesn't match data size");
delete[] weights_; delete[] weights_;
num_weights_ = 0; num_weights_ = 0;
weights_ = nullptr; weights_ = nullptr;
...@@ -114,7 +117,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -114,7 +117,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// check query boundries // check query boundries
if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_data_) { if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_data_) {
Log::Fatal("Initial query size doesn't equal to data"); Log::Fatal("Query size doesn't match data size");
delete[] query_boundaries_; delete[] query_boundaries_;
num_queries_ = 0; num_queries_ = 0;
query_boundaries_ = nullptr; query_boundaries_ = nullptr;
...@@ -123,7 +126,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -123,7 +126,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// contain initial score file // contain initial score file
if (init_score_ != nullptr && num_init_score_ != num_data_) { if (init_score_ != nullptr && num_init_score_ != num_data_) {
delete[] init_score_; delete[] init_score_;
Log::Fatal("Initial score size doesn't equal to data"); Log::Fatal("Initial score size doesn't match data size");
init_score_ = nullptr; init_score_ = nullptr;
num_init_score_ = 0; num_init_score_ = 0;
} }
...@@ -131,14 +134,14 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -131,14 +134,14 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size()); data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
// check weights // check weights
if (weights_ != nullptr && num_weights_ != num_all_data) { if (weights_ != nullptr && num_weights_ != num_all_data) {
Log::Fatal("Initial weights size doesn't equal to data"); Log::Fatal("Weights size doesn't match data size");
delete[] weights_; delete[] weights_;
num_weights_ = 0; num_weights_ = 0;
weights_ = nullptr; weights_ = nullptr;
} }
// check query boundries // check query boundries
if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_all_data) { if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_all_data) {
Log::Fatal("Initial query size doesn't equal to data"); Log::Fatal("Query size doesn't match data size");
delete[] query_boundaries_; delete[] query_boundaries_;
num_queries_ = 0; num_queries_ = 0;
query_boundaries_ = nullptr; query_boundaries_ = nullptr;
...@@ -146,7 +149,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -146,7 +149,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// contain initial score file // contain initial score file
if (init_score_ != nullptr && num_init_score_ != num_all_data) { if (init_score_ != nullptr && num_init_score_ != num_all_data) {
Log::Fatal("Initial score size doesn't equal to data"); Log::Fatal("Initial score size doesn't match data size");
delete[] init_score_; delete[] init_score_;
num_init_score_ = 0; num_init_score_ = 0;
init_score_ = nullptr; init_score_ = nullptr;
...@@ -200,9 +203,11 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -200,9 +203,11 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
if (init_score_ != nullptr) { if (init_score_ != nullptr) {
float* old_scores = init_score_; float* old_scores = init_score_;
num_init_score_ = num_data_; num_init_score_ = num_data_;
init_score_ = new float[num_init_score_]; init_score_ = new float[num_init_score_ * num_class_];
for (int k = 0; k < num_class_; ++k){
for (size_t i = 0; i < used_data_indices.size(); ++i) { for (size_t i = 0; i < used_data_indices.size(); ++i) {
init_score_[i] = old_scores[used_data_indices[i]]; init_score_[k * num_data_ + i] = old_scores[k * num_all_data + used_data_indices[i]];
}
} }
delete[] old_scores; delete[] old_scores;
} }
...@@ -214,13 +219,13 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -214,13 +219,13 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
void Metadata::SetInitScore(const float* init_score, data_size_t len) { void Metadata::SetInitScore(const float* init_score, data_size_t len) {
if (num_data_ != len) { if (len != num_data_ * num_class_) {
Log::Fatal("len of initial score is not same with #data"); Log::Fatal("Initial score size doesn't match data size");
} }
if (init_score_ != nullptr) { delete[] init_score_; } if (init_score_ != nullptr) { delete[] init_score_; }
num_init_score_ = num_data_; num_init_score_ = num_data_;
init_score_ = new float[num_init_score_]; init_score_ = new float[len];
for (data_size_t i = 0; i < num_init_score_; ++i) { for (data_size_t i = 0; i < len; ++i) {
init_score_[i] = init_score[i]; init_score_[i] = init_score[i];
} }
} }
...@@ -235,7 +240,7 @@ void Metadata::LoadWeights() { ...@@ -235,7 +240,7 @@ void Metadata::LoadWeights() {
if (reader.Lines().size() <= 0) { if (reader.Lines().size() <= 0) {
return; return;
} }
Log::Info("Start loading weights"); Log::Info("Loading weights...");
num_weights_ = static_cast<data_size_t>(reader.Lines().size()); num_weights_ = static_cast<data_size_t>(reader.Lines().size());
weights_ = new float[num_weights_]; weights_ = new float[num_weights_];
for (data_size_t i = 0; i < num_weights_; ++i) { for (data_size_t i = 0; i < num_weights_; ++i) {
...@@ -251,14 +256,30 @@ void Metadata::LoadInitialScore() { ...@@ -251,14 +256,30 @@ void Metadata::LoadInitialScore() {
TextReader<size_t> reader(init_score_filename_, false); TextReader<size_t> reader(init_score_filename_, false);
reader.ReadAllLines(); reader.ReadAllLines();
Log::Info("Start loading initial scores"); Log::Info("Loading initial scores...");
num_init_score_ = static_cast<data_size_t>(reader.Lines().size()); num_init_score_ = static_cast<data_size_t>(reader.Lines().size());
init_score_ = new float[num_init_score_];
init_score_ = new float[num_init_score_ * num_class_];
double tmp = 0.0f; double tmp = 0.0f;
if (num_class_ == 1){
for (data_size_t i = 0; i < num_init_score_; ++i) { for (data_size_t i = 0; i < num_init_score_; ++i) {
Common::Atof(reader.Lines()[i].c_str(), &tmp); Common::Atof(reader.Lines()[i].c_str(), &tmp);
init_score_[i] = static_cast<float>(tmp); init_score_[i] = static_cast<float>(tmp);
} }
} else {
std::vector<std::string> oneline_init_score;
for (data_size_t i = 0; i < num_init_score_; ++i) {
oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
if (static_cast<int>(oneline_init_score.size()) != num_class_){
Log::Fatal("Invalid initial score file. Redundant or insufficient columns.");
}
for (int k = 0; k < num_class_; ++k) {
Common::Atof(oneline_init_score[k].c_str(), &tmp);
init_score_[k * num_init_score_ + i] = static_cast<float>(tmp);
}
}
}
} }
void Metadata::LoadQueryBoundaries() { void Metadata::LoadQueryBoundaries() {
...@@ -271,7 +292,7 @@ void Metadata::LoadQueryBoundaries() { ...@@ -271,7 +292,7 @@ void Metadata::LoadQueryBoundaries() {
if (reader.Lines().size() <= 0) { if (reader.Lines().size() <= 0) {
return; return;
} }
Log::Info("Start loading query boundries"); Log::Info("Loading query boundaries...");
query_boundaries_ = new data_size_t[reader.Lines().size() + 1]; query_boundaries_ = new data_size_t[reader.Lines().size() + 1];
num_queries_ = static_cast<data_size_t>(reader.Lines().size()); num_queries_ = static_cast<data_size_t>(reader.Lines().size());
query_boundaries_[0] = 0; query_boundaries_[0] = 0;
...@@ -286,7 +307,7 @@ void Metadata::LoadQueryWeights() { ...@@ -286,7 +307,7 @@ void Metadata::LoadQueryWeights() {
if (weights_ == nullptr || query_boundaries_ == nullptr) { if (weights_ == nullptr || query_boundaries_ == nullptr) {
return; return;
} }
Log::Info("Start loading query weights"); Log::Info("Loading query weights...");
query_weights_ = new float[num_queries_]; query_weights_ = new float[num_queries_];
for (data_size_t i = 0; i < num_queries_; ++i) { for (data_size_t i = 0; i < num_queries_; ++i) {
query_weights_[i] = 0.0f; query_weights_[i] = 0.0f;
...@@ -310,7 +331,7 @@ void Metadata::LoadFromMemory(const void* memory) { ...@@ -310,7 +331,7 @@ void Metadata::LoadFromMemory(const void* memory) {
if (label_ != nullptr) { delete[] label_; } if (label_ != nullptr) { delete[] label_; }
label_ = new float[num_data_]; label_ = new float[num_data_];
std::memcpy(label_, mem_ptr, sizeof(float)*num_data_); std::memcpy(label_, mem_ptr, sizeof(float)*num_data_);
mem_ptr += sizeof(float)*num_weights_; mem_ptr += sizeof(float)*num_data_;
if (num_weights_ > 0) { if (num_weights_ > 0) {
if (weights_ != nullptr) { delete[] weights_; } if (weights_ != nullptr) { delete[] weights_; }
......
...@@ -72,7 +72,7 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat ...@@ -72,7 +72,7 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
std::ifstream tmp_file; std::ifstream tmp_file;
tmp_file.open(filename); tmp_file.open(filename);
if (!tmp_file.is_open()) { if (!tmp_file.is_open()) {
Log::Fatal("Data file: %s doesn't exist", filename); Log::Fatal("Data file %s doesn't exist'", filename);
} }
std::string line1, line2; std::string line1, line2;
if (has_header) { if (has_header) {
...@@ -83,12 +83,12 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat ...@@ -83,12 +83,12 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
std::getline(tmp_file, line1); std::getline(tmp_file, line1);
} else { } else {
Log::Fatal("Data file: %s at least should have one line", filename); Log::Fatal("Data file %s should have at least one line", filename);
} }
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
std::getline(tmp_file, line2); std::getline(tmp_file, line2);
} else { } else {
Log::Warning("Data file: %s only have one line", filename); Log::Warning("Data file %s only has one line", filename);
} }
tmp_file.close(); tmp_file.close();
int comma_cnt = 0, comma_cnt2 = 0; int comma_cnt = 0, comma_cnt2 = 0;
...@@ -120,7 +120,7 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat ...@@ -120,7 +120,7 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
} }
} }
if (type == DataType::INVALID) { if (type == DataType::INVALID) {
Log::Fatal("Unkown format of training data"); Log::Fatal("Unknown format of training data");
} }
Parser* ret = nullptr; Parser* ret = nullptr;
if (type == DataType::LIBSVM) { if (type == DataType::LIBSVM) {
...@@ -137,7 +137,7 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat ...@@ -137,7 +137,7 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
} }
if (label_idx < 0) { if (label_idx < 0) {
Log::Info("Data file: %s doesn't contain label column", filename); Log::Info("Data file %s doesn't contain a label column", filename);
} }
return ret; return ret;
} }
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
if (*str == ',') { if (*str == ',') {
++str; ++str;
} else if (*str != '\0') { } else if (*str != '\0') {
Log::Fatal("input format error, should be CSV"); Log::Fatal("Input format error when parsing as CSV");
} }
} }
} }
...@@ -66,7 +66,7 @@ public: ...@@ -66,7 +66,7 @@ public:
if (*str == '\t') { if (*str == '\t') {
++str; ++str;
} else if (*str != '\0') { } else if (*str != '\0') {
Log::Fatal("input format error, should be TSV"); Log::Fatal("Input format error when parsing as TSV");
} }
} }
} }
...@@ -79,7 +79,7 @@ public: ...@@ -79,7 +79,7 @@ public:
explicit LibSVMParser(int label_idx) explicit LibSVMParser(int label_idx)
:label_idx_(label_idx) { :label_idx_(label_idx) {
if (label_idx > 0) { if (label_idx > 0) {
Log::Fatal("label should be the first column in Libsvm file"); Log::Fatal("Label should be the first column in a LibSVM file");
} }
} }
inline void ParseOneLine(const char* str, inline void ParseOneLine(const char* str,
...@@ -99,7 +99,7 @@ public: ...@@ -99,7 +99,7 @@ public:
str = Common::Atof(str, &val); str = Common::Atof(str, &val);
out_features->emplace_back(idx, val); out_features->emplace_back(idx, val);
} else { } else {
Log::Fatal("input format error, should be LibSVM"); Log::Fatal("Input format error when parsing as LibSVM");
} }
str = Common::SkipSpaceAndTab(str); str = Common::SkipSpaceAndTab(str);
} }
......
...@@ -28,7 +28,7 @@ public: ...@@ -28,7 +28,7 @@ public:
: num_data_(num_data) { : num_data_(num_data) {
default_bin_ = static_cast<VAL_T>(default_bin); default_bin_ = static_cast<VAL_T>(default_bin);
if (default_bin_ != 0) { if (default_bin_ != 0) {
Log::Info("Warning: Having sparse feature with negative values. Will let negative values equal zero as well"); Log::Info("Warning: sparse feature with negative values, treating negative values as zero");
} }
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
void ConstructHistogram(data_size_t*, data_size_t , const score_t* , void ConstructHistogram(data_size_t*, data_size_t , const score_t* ,
const score_t* , HistogramBinEntry*) const override { const score_t* , HistogramBinEntry*) const override {
// Will use OrderedSparseBin->ConstructHistogram() instead // Will use OrderedSparseBin->ConstructHistogram() instead
Log::Info("Should use OrderedSparseBin->ConstructHistogram() instead"); Log::Info("Using OrderedSparseBin->ConstructHistogram() instead");
} }
data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data, data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment