Commit 2a8d38c5 authored by Qiwei Ye's avatar Qiwei Ye
Browse files

Merge branches 'master' and 'master' of https://github.com/Microsoft/LightGBM

parents 351b3d7e ed958eb2
...@@ -76,7 +76,7 @@ void Application::LoadParameters(int argc, char** argv) { ...@@ -76,7 +76,7 @@ void Application::LoadParameters(int argc, char** argv) {
ParameterAlias::KeyAliasTransform(&params); ParameterAlias::KeyAliasTransform(&params);
// read parameters from config file // read parameters from config file
if (params.count("config_file") > 0) { if (params.count("config_file") > 0) {
TextReader<size_t> config_reader(params["config_file"].c_str()); TextReader<size_t> config_reader(params["config_file"].c_str(), false);
config_reader.ReadAllLines(); config_reader.ReadAllLines();
if (config_reader.Lines().size() > 0) { if (config_reader.Lines().size() > 0) {
for (auto& line : config_reader.Lines()) { for (auto& line : config_reader.Lines()) {
...@@ -121,17 +121,14 @@ void Application::LoadData() { ...@@ -121,17 +121,14 @@ void Application::LoadData() {
// predition is needed if using input initial model(continued train) // predition is needed if using input initial model(continued train)
PredictFunction predict_fun = nullptr; PredictFunction predict_fun = nullptr;
Predictor* predictor = nullptr; Predictor* predictor = nullptr;
// load init model // need to continue train
if (config_.io_config.input_model.size() > 0) {
LoadModel();
if (boosting_->NumberOfSubModels() > 0) { if (boosting_->NumberOfSubModels() > 0) {
predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index); predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index, -1);
predict_fun = predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) { [&predictor](const std::vector<std::pair<int, float>>& features) {
return predictor->PredictRawOneLine(features); return predictor->PredictRawOneLine(features);
}; };
} }
}
// sync up random seed for data partition // sync up random seed for data partition
if (config_.is_parallel_find_bin) { if (config_.is_parallel_find_bin) {
config_.io_config.data_random_seed = config_.io_config.data_random_seed =
...@@ -139,9 +136,7 @@ void Application::LoadData() { ...@@ -139,9 +136,7 @@ void Application::LoadData() {
} }
train_data_ = new Dataset(config_.io_config.data_filename.c_str(), train_data_ = new Dataset(config_.io_config.data_filename.c_str(),
config_.io_config.input_init_score.c_str(), config_.io_config.input_init_score.c_str(),
config_.io_config.max_bin, config_.io_config,
config_.io_config.data_random_seed,
config_.io_config.is_enable_sparse,
predict_fun); predict_fun);
// load Training data // load Training data
if (config_.is_parallel_find_bin) { if (config_.is_parallel_find_bin) {
...@@ -158,7 +153,7 @@ void Application::LoadData() { ...@@ -158,7 +153,7 @@ void Application::LoadData() {
train_data_->SaveBinaryFile(); train_data_->SaveBinaryFile();
} }
// create training metric // create training metric
if (config_.metric_config.is_provide_training_metric) { if (config_.boosting_config->is_provide_training_metric) {
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
Metric* metric = Metric* metric =
Metric::CreateMetric(metric_type, config_.metric_config); Metric::CreateMetric(metric_type, config_.metric_config);
...@@ -173,9 +168,7 @@ void Application::LoadData() { ...@@ -173,9 +168,7 @@ void Application::LoadData() {
// add // add
valid_datas_.push_back( valid_datas_.push_back(
new Dataset(config_.io_config.valid_data_filenames[i].c_str(), new Dataset(config_.io_config.valid_data_filenames[i].c_str(),
config_.io_config.max_bin, config_.io_config,
config_.io_config.data_random_seed,
config_.io_config.is_enable_sparse,
predict_fun)); predict_fun));
// load validation data like train data // load validation data like train data
valid_datas_.back()->LoadValidationData(train_data_, valid_datas_.back()->LoadValidationData(train_data_,
...@@ -217,12 +210,13 @@ void Application::InitTrain() { ...@@ -217,12 +210,13 @@ void Application::InitTrain() {
gbdt_config->tree_config.feature_fraction_seed = gbdt_config->tree_config.feature_fraction_seed =
GlobalSyncUpByMin<int>(gbdt_config->tree_config.feature_fraction_seed); GlobalSyncUpByMin<int>(gbdt_config->tree_config.feature_fraction_seed);
gbdt_config->tree_config.feature_fraction = gbdt_config->tree_config.feature_fraction =
GlobalSyncUpByMin<double>(gbdt_config->tree_config.feature_fraction); GlobalSyncUpByMin<float>(gbdt_config->tree_config.feature_fraction);
} }
} }
// create boosting // create boosting
boosting_ = boosting_ =
Boosting::CreateBoosting(config_.boosting_type, config_.boosting_config); Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.input_model.c_str());
// create objective function // create objective function
objective_fun_ = objective_fun_ =
ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
...@@ -232,9 +226,8 @@ void Application::InitTrain() { ...@@ -232,9 +226,8 @@ void Application::InitTrain() {
// initialize the objective function // initialize the objective function
objective_fun_->Init(train_data_->metadata(), train_data_->num_data()); objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
// initialize the boosting // initialize the boosting
boosting_->Init(train_data_, objective_fun_, boosting_->Init(config_.boosting_config, train_data_, objective_fun_,
ConstPtrInVectorWarpper<Metric>(train_metric_), ConstPtrInVectorWarpper<Metric>(train_metric_));
config_.io_config.output_model.c_str());
// add validation data into boosting // add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) { for (size_t i = 0; i < valid_datas_.size(); ++i) {
boosting_->AddDataset(valid_datas_[i], boosting_->AddDataset(valid_datas_[i],
...@@ -244,36 +237,41 @@ void Application::InitTrain() { ...@@ -244,36 +237,41 @@ void Application::InitTrain() {
} }
void Application::Train() { void Application::Train() {
Log::Info("Start train"); Log::Info("Start train ...");
boosting_->Train(); int total_iter = config_.boosting_config->num_iterations;
Log::Info("Finish train"); bool is_finished = false;
bool need_eval = true;
auto start_time = std::chrono::high_resolution_clock::now();
for (int iter = 0; iter < total_iter && !is_finished; ++iter) {
is_finished = boosting_->TrainOneIter(nullptr, nullptr, need_eval);
auto end_time = std::chrono::high_resolution_clock::now();
// output used time per iteration
Log::Info("%f seconds elapsed, finished %d iteration", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1);
boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str());
}
is_finished = true;
// save model to file
boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str());
Log::Info("Finished train");
} }
void Application::Predict() { void Application::Predict() {
// create predictor // create predictor
Predictor predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index); Predictor predictor(boosting_, config_.io_config.is_sigmoid,
predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str()); config_.predict_leaf_index, config_.io_config.num_model_predict);
predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finish predict."); Log::Info("Finish predict.");
} }
void Application::InitPredict() { void Application::InitPredict() {
boosting_ = boosting_ =
Boosting::CreateBoosting(config_.boosting_type, config_.boosting_config); Boosting::CreateBoosting(config_.io_config.input_model.c_str());
LoadModel();
Log::Info("Finish predict initilization."); Log::Info("Finish predict initilization.");
} }
void Application::LoadModel() {
TextReader<size_t> model_reader(config_.io_config.input_model.c_str());
model_reader.ReadAllLines();
std::stringstream ss;
for (auto& line : model_reader.Lines()) {
ss << line << '\n';
}
boosting_->ModelsFromString(ss.str(), config_.io_config.num_model_predict);
}
template<typename T> template<typename T>
T Application::GlobalSyncUpByMin(T& local) { T Application::GlobalSyncUpByMin(T& local) {
T global = local; T global = local;
......
...@@ -28,18 +28,20 @@ public: ...@@ -28,18 +28,20 @@ public:
* \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification) * \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification)
* \param predict_leaf_index True if output leaf index instead of prediction score * \param predict_leaf_index True if output leaf index instead of prediction score
*/ */
Predictor(const Boosting* boosting, bool is_simgoid, bool predict_leaf_index) Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index, int num_used_model)
: is_simgoid_(is_simgoid), predict_leaf_index(predict_leaf_index) { : is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index),
num_used_model_(num_used_model) {
boosting_ = boosting; boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1; num_features_ = boosting_->MaxFeatureIdx() + 1;
num_class_ = boosting_->NumberOfClass();
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
{ {
num_threads_ = omp_get_num_threads(); num_threads_ = omp_get_num_threads();
} }
features_ = new double*[num_threads_]; features_ = new float*[num_threads_];
for (int i = 0; i < num_threads_; ++i) { for (int i = 0; i < num_threads_; ++i) {
features_[i] = new double[num_features_]; features_[i] = new float[num_features_];
} }
} }
/*! /*!
...@@ -59,10 +61,10 @@ public: ...@@ -59,10 +61,10 @@ public:
* \param features Feature for this record * \param features Feature for this record
* \return Prediction result * \return Prediction result
*/ */
double PredictRawOneLine(const std::vector<std::pair<int, double>>& features) { float PredictRawOneLine(const std::vector<std::pair<int, float>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result without sigmoid transformation // get result without sigmoid transformation
return boosting_->PredictRaw(features_[tid]); return boosting_->PredictRaw(features_[tid], num_used_model_);
} }
/*! /*!
...@@ -70,10 +72,10 @@ public: ...@@ -70,10 +72,10 @@ public:
* \param features Feature for this record * \param features Feature for this record
* \return Predictied leaf index * \return Predictied leaf index
*/ */
std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, float>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result for leaf index // get result for leaf index
return boosting_->PredictLeafIndex(features_[tid]); return boosting_->PredictLeafIndex(features_[tid], num_used_model_);
} }
/*! /*!
...@@ -81,18 +83,30 @@ public: ...@@ -81,18 +83,30 @@ public:
* \param features Feature of this record * \param features Feature of this record
* \return Prediction result * \return Prediction result
*/ */
double PredictOneLine(const std::vector<std::pair<int, double>>& features) { float PredictOneLine(const std::vector<std::pair<int, float>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed // get result with sigmoid transform if needed
return boosting_->Predict(features_[tid]); return boosting_->Predict(features_[tid], num_used_model_);
} }
/*!
* \brief prediction for multiclass classification
* \param features Feature of this record
* \return Prediction result
*/
std::vector<float> PredictMulticlassOneLine(const std::vector<std::pair<int, float>>& features) {
const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed
return boosting_->PredictMulticlass(features_[tid], num_used_model_);
}
/*! /*!
* \brief predicting on data, then saving result to disk * \brief predicting on data, then saving result to disk
* \param data_filename Filename of data * \param data_filename Filename of data
* \param has_label True if this data contains label * \param has_label True if this data contains label
* \param result_filename Filename of output result * \param result_filename Filename of output result
*/ */
void Predict(const char* data_filename, const char* result_filename) { void Predict(const char* data_filename, const char* result_filename, bool has_header) {
FILE* result_file; FILE* result_file;
#ifdef _MSC_VER #ifdef _MSC_VER
...@@ -104,53 +118,55 @@ public: ...@@ -104,53 +118,55 @@ public:
if (result_file == NULL) { if (result_file == NULL) {
Log::Fatal("Predition result file %s doesn't exists", data_filename); Log::Fatal("Predition result file %s doesn't exists", data_filename);
} }
bool has_label = false; Parser* parser = Parser::CreateParser(data_filename, has_header, num_features_, boosting_->LabelIdx());
Parser* parser = Parser::CreateParser(data_filename, num_features_, &has_label);
if (parser == nullptr) { if (parser == nullptr) {
Log::Fatal("Recongnizing input data format failed, filename %s", data_filename); Log::Fatal("Recongnizing input data format failed, filename %s", data_filename);
} }
// function for parse data // function for parse data
std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun; std::function<void(const char*, std::vector<std::pair<int, float>>*)> parser_fun;
double tmp_label; float tmp_label;
if (has_label) {
// parse function with label
parser_fun = [this, &parser, &tmp_label] parser_fun = [this, &parser, &tmp_label]
(const char* buffer, std::vector<std::pair<int, double>>* feature) { (const char* buffer, std::vector<std::pair<int, float>>* feature) {
parser->ParseOneLine(buffer, feature, &tmp_label); parser->ParseOneLine(buffer, feature, &tmp_label);
}; };
Log::Info("Start prediction for data %s with labels", data_filename);
} else { std::function<std::string(const std::vector<std::pair<int, float>>&)> predict_fun;
// parse function without label if (num_class_ > 1) {
parser_fun = [this, &parser] predict_fun = [this](const std::vector<std::pair<int, float>>& features){
(const char* buffer, std::vector<std::pair<int, double>>* feature) { std::vector<float> prediction = PredictMulticlassOneLine(features);
parser->ParseOneLine(buffer, feature); std::stringstream result_stream_buf;
for (size_t i = 0; i < prediction.size(); ++i){
if (i > 0) {
result_stream_buf << '\t';
}
result_stream_buf << prediction[i];
}
return result_stream_buf.str();
}; };
Log::Info("Start prediction for data %s without label", data_filename);
} }
std::function<std::string(const std::vector<std::pair<int, double>>&)> predict_fun; else if (is_predict_leaf_index_) {
if (predict_leaf_index) { predict_fun = [this](const std::vector<std::pair<int, float>>& features){
predict_fun = [this](const std::vector<std::pair<int, double>>& features){
std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features); std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features);
std::stringstream result_ss; std::stringstream result_stream_buf;
for (size_t i = 0; i < predicted_leaf_index.size(); ++i){ for (size_t i = 0; i < predicted_leaf_index.size(); ++i){
if (i > 0) { if (i > 0) {
result_ss << '\t'; result_stream_buf << '\t';
} }
result_ss << predicted_leaf_index[i]; result_stream_buf << predicted_leaf_index[i];
} }
return result_ss.str(); return result_stream_buf.str();
}; };
} }
else { else {
if (is_simgoid_) { if (is_simgoid_) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, float>>& features){
return std::to_string(PredictOneLine(features)); return std::to_string(PredictOneLine(features));
}; };
} }
else { else {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, float>>& features){
return std::to_string(PredictRawOneLine(features)); return std::to_string(PredictRawOneLine(features));
}; };
} }
...@@ -158,10 +174,10 @@ public: ...@@ -158,10 +174,10 @@ public:
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
[this, &parser_fun, &predict_fun, &result_file] [this, &parser_fun, &predict_fun, &result_file]
(data_size_t, const std::vector<std::string>& lines) { (data_size_t, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, float>> oneline_features;
std::vector<std::string> pred_result(lines.size(), ""); std::vector<std::string> pred_result(lines.size(), "");
#pragma omp parallel for schedule(static) private(oneline_features) #pragma omp parallel for schedule(static) private(oneline_features)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); i++) { for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
oneline_features.clear(); oneline_features.clear();
// parser // parser
parser_fun(lines[i].c_str(), &oneline_features); parser_fun(lines[i].c_str(), &oneline_features);
...@@ -173,7 +189,7 @@ public: ...@@ -173,7 +189,7 @@ public:
fprintf(result_file, "%s\n", pred_result[i].c_str()); fprintf(result_file, "%s\n", pred_result[i].c_str());
} }
}; };
TextReader<data_size_t> predict_data_reader(data_filename); TextReader<data_size_t> predict_data_reader(data_filename, has_header);
predict_data_reader.ReadAllAndProcessParallel(process_fun); predict_data_reader.ReadAllAndProcessParallel(process_fun);
fclose(result_file); fclose(result_file);
...@@ -181,10 +197,10 @@ public: ...@@ -181,10 +197,10 @@ public:
} }
private: private:
int PutFeatureValuesToBuffer(const std::vector<std::pair<int, double>>& features) { int PutFeatureValuesToBuffer(const std::vector<std::pair<int, float>>& features) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
// init feature value // init feature value
std::memset(features_[tid], 0, sizeof(double)*num_features_); std::memset(features_[tid], 0, sizeof(float)*num_features_);
// put feature value // put feature value
for (const auto& p : features) { for (const auto& p : features) {
if (p.first < num_features_) { if (p.first < num_features_) {
...@@ -196,15 +212,19 @@ private: ...@@ -196,15 +212,19 @@ private:
/*! \brief Boosting model */ /*! \brief Boosting model */
const Boosting* boosting_; const Boosting* boosting_;
/*! \brief Buffer for feature values */ /*! \brief Buffer for feature values */
double** features_; float** features_;
/*! \brief Number of features */ /*! \brief Number of features */
int num_features_; int num_features_;
/*! \brief Number of classes */
int num_class_;
/*! \brief True if need to predict result with sigmoid transform */ /*! \brief True if need to predict result with sigmoid transform */
bool is_simgoid_; bool is_simgoid_;
/*! \brief Number of threads */ /*! \brief Number of threads */
int num_threads_; int num_threads_;
/*! \brief True if output leaf index instead of prediction score */ /*! \brief True if output leaf index instead of prediction score */
bool predict_leaf_index; bool is_predict_leaf_index_;
/*! \brief Number of used model */
int num_used_model_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -3,13 +3,57 @@ ...@@ -3,13 +3,57 @@
namespace LightGBM { namespace LightGBM {
Boosting* Boosting::CreateBoosting(BoostingType type, BoostingType GetBoostingTypeFromModelFile(const char* filename) {
const BoostingConfig* config) { TextReader<size_t> model_reader(filename, true);
std::string type = model_reader.first_line();
if (type == std::string("gbdt")) {
return BoostingType::kGBDT;
}
return BoostingType::kUnknow;
}
void LoadFileToBoosting(Boosting* boosting, const char* filename) {
if (boosting != nullptr) {
TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines();
std::stringstream str_buf;
for (auto& line : model_reader.Lines()) {
str_buf << line << '\n';
}
boosting->ModelsFromString(str_buf.str());
}
}
Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
if (filename[0] == '\0') {
if (type == BoostingType::kGBDT) { if (type == BoostingType::kGBDT) {
return new GBDT(config); return new GBDT();
} else { } else {
return nullptr; return nullptr;
} }
} else {
Boosting* ret = nullptr;
auto type_in_file = GetBoostingTypeFromModelFile(filename);
if (type_in_file == type) {
if (type == BoostingType::kGBDT) {
ret = new GBDT();
}
LoadFileToBoosting(ret, filename);
} else {
Log::Fatal("Boosting type in parameter is not same with the type in model file");
}
return ret;
}
}
Boosting* Boosting::CreateBoosting(const char* filename) {
auto type = GetBoostingTypeFromModelFile(filename);
Boosting* ret = nullptr;
if (type == BoostingType::kGBDT) {
ret = new GBDT();
}
LoadFileToBoosting(ret, filename);
return ret;
} }
} // namespace LightGBM } // namespace LightGBM
...@@ -12,20 +12,20 @@ ...@@ -12,20 +12,20 @@
#include <chrono> #include <chrono>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
namespace LightGBM { namespace LightGBM {
GBDT::GBDT(const BoostingConfig* config) GBDT::GBDT()
: tree_learner_(nullptr), train_score_updater_(nullptr), : train_score_updater_(nullptr),
gradients_(nullptr), hessians_(nullptr), gradients_(nullptr), hessians_(nullptr),
out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) { out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) {
max_feature_idx_ = 0;
gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
early_stopping_round_ = gbdt_config_->early_stopping_round;
} }
GBDT::~GBDT() { GBDT::~GBDT() {
if (tree_learner_ != nullptr) { delete tree_learner_; } for (auto& tree_learner: tree_learner_){
if (tree_learner != nullptr) { delete tree_learner; }
}
if (gradients_ != nullptr) { delete[] gradients_; } if (gradients_ != nullptr) { delete[] gradients_; }
if (hessians_ != nullptr) { delete[] hessians_; } if (hessians_ != nullptr) { delete[] hessians_; }
if (out_of_bag_data_indices_ != nullptr) { delete[] out_of_bag_data_indices_; } if (out_of_bag_data_indices_ != nullptr) { delete[] out_of_bag_data_indices_; }
...@@ -39,29 +39,40 @@ GBDT::~GBDT() { ...@@ -39,29 +39,40 @@ GBDT::~GBDT() {
} }
} }
void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_function, void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics, const char* output_model_filename) { const std::vector<const Metric*>& training_metrics) {
gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
iter_ = 0;
max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round;
train_data_ = train_data; train_data_ = train_data;
num_class_ = config->num_class;
tree_learner_ = std::vector<TreeLearner*>(num_class_, nullptr);
// create tree learner // create tree learner
tree_learner_ = for (int i = 0; i < num_class_; ++i){
tree_learner_[i] =
TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config); TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config);
// init tree learner // init tree learner
tree_learner_->Init(train_data_); tree_learner_[i]->Init(train_data_);
}
object_function_ = object_function; object_function_ = object_function;
// push training metrics // push training metrics
for (const auto& metric : training_metrics) { for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric); training_metrics_.push_back(metric);
} }
// create score tracker // create score tracker
train_score_updater_ = new ScoreUpdater(train_data_); train_score_updater_ = new ScoreUpdater(train_data_, num_class_);
num_data_ = train_data_->num_data(); num_data_ = train_data_->num_data();
// create buffer for gradients and hessians // create buffer for gradients and hessians
gradients_ = new score_t[num_data_]; if (object_function_ != nullptr) {
hessians_ = new score_t[num_data_]; gradients_ = new score_t[num_data_ * num_class_];
hessians_ = new score_t[num_data_ * num_class_];
}
// get max feature index // get max feature index
max_feature_idx_ = train_data_->num_total_features() - 1; max_feature_idx_ = train_data_->num_total_features() - 1;
// get label index
label_idx_ = train_data_->label_idx();
// if need bagging, create buffer // if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) { if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
out_of_bag_data_indices_ = new data_size_t[num_data_]; out_of_bag_data_indices_ = new data_size_t[num_data_];
...@@ -75,22 +86,12 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct ...@@ -75,22 +86,12 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct
// initialize random generator // initialize random generator
random_ = Random(gbdt_config_->bagging_seed); random_ = Random(gbdt_config_->bagging_seed);
// open model output file
#ifdef _MSC_VER
fopen_s(&output_model_file, output_model_filename, "w");
#else
output_model_file = fopen(output_model_filename, "w");
#endif
// output models
fprintf(output_model_file, "%s", this->ModelsToString().c_str());
} }
void GBDT::AddDataset(const Dataset* valid_data, void GBDT::AddDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) { const std::vector<const Metric*>& valid_metrics) {
// for a validation dataset, we need its score and metric // for a validation dataset, we need its score and metric
valid_score_updater_.push_back(new ScoreUpdater(valid_data)); valid_score_updater_.push_back(new ScoreUpdater(valid_data, num_class_));
valid_metrics_.emplace_back(); valid_metrics_.emplace_back();
best_iter_.emplace_back(); best_iter_.emplace_back();
best_score_.emplace_back(); best_score_.emplace_back();
...@@ -102,7 +103,7 @@ void GBDT::AddDataset(const Dataset* valid_data, ...@@ -102,7 +103,7 @@ void GBDT::AddDataset(const Dataset* valid_data,
} }
void GBDT::Bagging(int iter) { void GBDT::Bagging(int iter, const int curr_class) {
// if need bagging // if need bagging
if (out_of_bag_data_indices_ != nullptr && iter % gbdt_config_->bagging_freq == 0) { if (out_of_bag_data_indices_ != nullptr && iter % gbdt_config_->bagging_freq == 0) {
// if doesn't have query data // if doesn't have query data
...@@ -151,150 +152,244 @@ void GBDT::Bagging(int iter) { ...@@ -151,150 +152,244 @@ void GBDT::Bagging(int iter) {
} }
Log::Info("re-bagging, using %d data to train", bag_data_cnt_); Log::Info("re-bagging, using %d data to train", bag_data_cnt_);
// set bagging data to tree learner // set bagging data to tree learner
tree_learner_->SetBaggingData(bag_data_indices_, bag_data_cnt_); tree_learner_[curr_class]->SetBaggingData(bag_data_indices_, bag_data_cnt_);
} }
} }
void GBDT::UpdateScoreOutOfBag(const Tree* tree) { void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
// we need to predict out-of-bag socres of data for boosting // we need to predict out-of-bag socres of data for boosting
if (out_of_bag_data_indices_ != nullptr) { if (out_of_bag_data_indices_ != nullptr) {
train_score_updater_-> train_score_updater_->
AddScore(tree, out_of_bag_data_indices_, out_of_bag_data_cnt_); AddScore(tree, out_of_bag_data_indices_, out_of_bag_data_cnt_, curr_class);
} }
} }
void GBDT::Train() { bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
// training start time
auto start_time = std::chrono::high_resolution_clock::now();
for (int iter = 0; iter < gbdt_config_->num_iterations; ++iter) {
// boosting first // boosting first
if (gradient == nullptr || hessian == nullptr) {
Boosting(); Boosting();
gradient = gradients_;
hessian = hessians_;
}
for (int curr_class = 0; curr_class < num_class_; ++curr_class){
// bagging logic // bagging logic
Bagging(iter); Bagging(iter_, curr_class);
// train a new tree // train a new tree
Tree * new_tree = TrainOneTree(); Tree * new_tree = tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian+ curr_class * num_data_);
// if cannot learn a new tree, then stop // if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) { if (new_tree->num_leaves() <= 1) {
Log::Info("Can't training anymore, there isn't any leaf meets split requirements."); Log::Info("Can't training anymore, there isn't any leaf meets split requirements.");
break; return true;
} }
// shrinkage by learning rate // shrinkage by learning rate
new_tree->Shrinkage(gbdt_config_->learning_rate); new_tree->Shrinkage(gbdt_config_->learning_rate);
// update score // update score
UpdateScore(new_tree); UpdateScore(new_tree, curr_class);
UpdateScoreOutOfBag(new_tree); UpdateScoreOutOfBag(new_tree, curr_class);
// print message for metric
bool is_early_stopping = OutputMetric(iter + 1);
// add model // add model
models_.push_back(new_tree); models_.push_back(new_tree);
// save model to file per iteration
if (early_stopping_round_ > 0){
// if use early stopping, save previous model at (iter - early_stopping_round_) iteration
if (iter >= early_stopping_round_){
fprintf(output_model_file, "Tree=%d\n", iter - early_stopping_round_);
Tree * printing_tree = models_.at(iter - early_stopping_round_);
fprintf(output_model_file, "%s\n", printing_tree->ToString().c_str());
fflush(output_model_file);
}
}
else{
fprintf(output_model_file, "Tree=%d\n", iter);
fprintf(output_model_file, "%s\n", new_tree->ToString().c_str());
fflush(output_model_file);
}
auto end_time = std::chrono::high_resolution_clock::now();
// output used time per iteration
Log::Info("%f seconds elapsed, finished %d iteration", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1);
if (is_early_stopping) {
// close file with an early-stopping message
Log::Info("Early stopping at iteration %d, the best iteration round is %d", iter + 1, iter + 1 - early_stopping_round_);
fclose(output_model_file);
return;
} }
bool is_met_early_stopping = false;
// print message for metric
if (is_eval) {
is_met_early_stopping = OutputMetric(iter_ + 1);
} }
// close file ++iter_;
if (early_stopping_round_ > 0) { if (is_met_early_stopping) {
// save remaining models Log::Info("Early stopping at iteration %d, the best iteration round is %d",
for (int iter = gbdt_config_->num_iterations - early_stopping_round_; iter < static_cast<int>(models_.size()); ++iter){ iter_, iter_ - early_stopping_round_);
fprintf(output_model_file, "Tree=%d\n", iter); // pop last early_stopping_round_ models
fprintf(output_model_file, "%s\n", models_.at(iter)->ToString().c_str()); for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
delete models_.back();
models_.pop_back();
} }
fflush(output_model_file);
} }
fclose(output_model_file); return is_met_early_stopping;
}
Tree* GBDT::TrainOneTree() {
return tree_learner_->Train(gradients_, hessians_);
} }
void GBDT::UpdateScore(const Tree* tree) { void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
// update training score // update training score
train_score_updater_->AddScore(tree_learner_); train_score_updater_->AddScore(tree_learner_[curr_class], curr_class);
// update validation score // update validation score
for (auto& score_tracker : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
score_tracker->AddScore(tree); score_updater->AddScore(tree, curr_class);
} }
} }
bool GBDT::OutputMetric(int iter) { bool GBDT::OutputMetric(int iter) {
bool ret = false; bool ret = false;
// print training metric // print training metric
if ((iter % gbdt_config_->output_freq) == 0) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
sub_metric->PrintAndGetLoss(iter, train_score_updater_->score()); auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score());
Log::Info("Iteration:%d, %s : %s", iter, name, Common::ArrayToString<float>(scores, ' ').c_str());
}
} }
// print validation metric // print validation metric
if ((iter % gbdt_config_->output_freq) == 0 || early_stopping_round_ > 0) {
for (size_t i = 0; i < valid_metrics_.size(); ++i) { for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) { for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
score_t test_score_ = valid_metrics_[i][j]->PrintAndGetLoss(iter, valid_score_updater_[i]->score()); auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
if (!ret && early_stopping_round_ > 0){ if ((iter % gbdt_config_->output_freq) == 0) {
bool the_bigger_the_better_ = valid_metrics_[i][j]->the_bigger_the_better; auto name = valid_metrics_[i][j]->GetName();
Log::Info("Iteration:%d, %s : %s", iter, name, Common::ArrayToString<float>(test_scores, ' ').c_str());
}
if (!ret && early_stopping_round_ > 0) {
bool the_bigger_the_better = valid_metrics_[i][j]->is_bigger_better();
if (best_score_[i][j] < 0 if (best_score_[i][j] < 0
|| (!the_bigger_the_better_ && test_score_ < best_score_[i][j]) || (!the_bigger_the_better && test_scores.back() < best_score_[i][j])
|| ( the_bigger_the_better_ && test_score_ > best_score_[i][j])){ || (the_bigger_the_better && test_scores.back() > best_score_[i][j])) {
best_score_[i][j] = test_score_; best_score_[i][j] = test_scores.back();
best_iter_[i][j] = iter; best_iter_[i][j] = iter;
} } else {
else {
if (iter - best_iter_[i][j] >= early_stopping_round_) ret = true; if (iter - best_iter_[i][j] >= early_stopping_round_) ret = true;
} }
} }
} }
} }
}
return ret;
}
/*! \brief Get eval result */
std::vector<std::string> GBDT::EvalCurrent(bool is_eval_train) const {
std::vector<std::string> ret;
if (is_eval_train) {
for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score());
std::stringstream str_buf;
str_buf << name << " : " << Common::ArrayToString<float>(scores, ' ');
ret.emplace_back(str_buf.str());
}
}
for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
auto name = valid_metrics_[i][j]->GetName();
auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
std::stringstream str_buf;
str_buf << name << " : " << Common::ArrayToString<float>(test_scores, ' ');
ret.emplace_back(str_buf.str());
}
}
return ret;
}
/*! \brief Get prediction result */
const std::vector<const score_t*> GBDT::PredictCurrent(bool is_predict_train) const {
std::vector<const score_t*> ret;
if (is_predict_train) {
ret.push_back(train_score_updater_->score());
}
for (size_t i = 0; i < valid_metrics_.size(); ++i) {
ret.push_back(valid_score_updater_[i]->score());
}
return ret; return ret;
} }
void GBDT::Boosting() { void GBDT::Boosting() {
if (object_function_ == nullptr) {
Log::Fatal("No object function provided");
}
// objective function will calculate gradients and hessians // objective function will calculate gradients and hessians
object_function_-> object_function_->
GetGradients(train_score_updater_->score(), gradients_, hessians_); GetGradients(train_score_updater_->score(), gradients_, hessians_);
} }
void GBDT::SaveModelToFile(bool is_finish, const char* filename) {
std::string GBDT::ModelsToString() const { // first time to this function, open file
// serialize this object to string if (saved_model_size_ == -1) {
std::stringstream ss; model_output_file_.open(filename);
// output model type
model_output_file_ << "gbdt" << std::endl;
// output number of class
model_output_file_ << "num_class=" << num_class_ << std::endl;
// output label index
model_output_file_ << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx // output max_feature_idx
ss << "max_feature_idx=" << max_feature_idx_ << std::endl; model_output_file_ << "max_feature_idx=" << max_feature_idx_ << std::endl;
// output sigmoid parameter // output sigmoid parameter
ss << "sigmoid=" << object_function_->GetSigmoid() << std::endl; model_output_file_ << "sigmoid=" << object_function_->GetSigmoid() << std::endl;
ss << std::endl; model_output_file_ << std::endl;
saved_model_size_ = 0;
}
// already saved
if (!model_output_file_.is_open()) {
return;
}
int rest = static_cast<int>(models_.size()) - early_stopping_round_ * num_class_;
// output tree models // output tree models
for (size_t i = 0; i < models_.size(); ++i) { for (int i = saved_model_size_; i < rest; ++i) {
ss << "Tree=" << i << std::endl; model_output_file_ << "Tree=" << i << std::endl;
ss << models_[i]->ToString() << std::endl; model_output_file_ << models_[i]->ToString() << std::endl;
}
saved_model_size_ = Common::Max(saved_model_size_, rest);
model_output_file_.flush();
// training finished, can close file
if (is_finish) {
for (int i = saved_model_size_; i < static_cast<int>(models_.size()); ++i) {
model_output_file_ << "Tree=" << i << std::endl;
model_output_file_ << models_[i]->ToString() << std::endl;
}
model_output_file_ << std::endl << FeatureImportance() << std::endl;
model_output_file_.close();
} }
return ss.str();
} }
void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) { void GBDT::ModelsFromString(const std::string& model_str) {
// use serialized string to restore this object // use serialized string to restore this object
models_.clear(); models_.clear();
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n'); std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
size_t i = 0; size_t i = 0;
// get number of class
while (i < lines.size()) {
size_t find_pos = lines[i].find("num_class=");
if (find_pos != std::string::npos) {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atoi(strs[1].c_str(), &num_class_);
++i;
break;
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't contain number of class");
return;
}
// get index of label
i = 0;
while (i < lines.size()) {
size_t find_pos = lines[i].find("label_index=");
if (find_pos != std::string::npos) {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atoi(strs[1].c_str(), &label_idx_);
++i;
break;
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't contain label index");
return;
}
// get max_feature_idx first // get max_feature_idx first
i = 0;
while (i < lines.size()) { while (i < lines.size()) {
size_t find_pos = lines[i].find("max_feature_idx="); size_t find_pos = lines[i].find("max_feature_idx=");
if (find_pos != std::string::npos) { if (find_pos != std::string::npos) {
...@@ -338,40 +433,86 @@ void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) { ...@@ -338,40 +433,86 @@ void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) {
int end = static_cast<int>(i); int end = static_cast<int>(i);
std::string tree_str = Common::Join(lines, start, end, '\n'); std::string tree_str = Common::Join(lines, start, end, '\n');
models_.push_back(new Tree(tree_str)); models_.push_back(new Tree(tree_str));
if (num_used_model > 0 && models_.size() >= static_cast<size_t>(num_used_model)) {
break;
}
} else { } else {
++i; ++i;
} }
} }
Log::Info("%d models has been loaded\n", models_.size()); Log::Info("%d models has been loaded\n", models_.size());
} }
double GBDT::PredictRaw(const double* value) const { std::string GBDT::FeatureImportance() const {
double ret = 0.0; std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
for (size_t i = 0; i < models_.size(); ++i) { for (size_t iter = 0; iter < models_.size(); ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
++feature_importances[models_[iter]->split_feature_real(split_idx)];
}
}
// store the importance first
std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) {
pairs.emplace_back(feature_importances[i], train_data_->feature_names()[i]);
}
// sort the importance
std::sort(pairs.begin(), pairs.end(),
[](const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first;
});
std::stringstream str_buf;
// write to model file
str_buf << std::endl << "feature importances:" << std::endl;
for (size_t i = 0; i < pairs.size(); ++i) {
str_buf << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
}
return str_buf.str();
}
float GBDT::PredictRaw(const float* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
float ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) {
ret += models_[i]->Predict(value); ret += models_[i]->Predict(value);
} }
return ret; return ret;
} }
double GBDT::Predict(const double* value) const { float GBDT::Predict(const float* value, int num_used_model) const {
double ret = 0.0; if (num_used_model < 0) {
for (size_t i = 0; i < models_.size(); ++i) { num_used_model = static_cast<int>(models_.size());
}
float ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) {
ret += models_[i]->Predict(value); ret += models_[i]->Predict(value);
} }
// if need sigmoid transform // if need sigmoid transform
if (sigmoid_ > 0) { if (sigmoid_ > 0) {
ret = 1.0 / (1.0 + std::exp(- 2.0f * sigmoid_ * ret)); ret = 1.0f / (1.0f + std::exp(- 2.0f * sigmoid_ * ret));
}
return ret;
}
std::vector<float> GBDT::PredictMulticlass(const float* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size()) / num_class_;
} }
std::vector<float> ret(num_class_, 0.0f);
for (int i = 0; i < num_used_model; ++i) {
for (int j = 0; j < num_class_; ++j){
ret[j] += models_[i * num_class_ + j] -> Predict(value);
}
}
Common::Softmax(&ret);
return ret; return ret;
} }
std::vector<int> GBDT::PredictLeafIndex(const double* value) const { std::vector<int> GBDT::PredictLeafIndex(const float* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
std::vector<int> ret; std::vector<int> ret;
for (size_t i = 0; i < models_.size(); ++i) { for (int i = 0; i < num_used_model; ++i) {
ret.push_back(models_[i]->PredictLeafIndex(value)); ret.push_back(models_[i]->PredictLeafIndex(value));
} }
return ret; return ret;
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <cstdio> #include <cstdio>
#include <vector> #include <vector>
#include <string> #include <string>
#include <fstream>
namespace LightGBM { namespace LightGBM {
/*! /*!
...@@ -16,9 +17,8 @@ class GBDT: public Boosting { ...@@ -16,9 +17,8 @@ class GBDT: public Boosting {
public: public:
/*! /*!
* \brief Constructor * \brief Constructor
* \param config Config of GBDT
*/ */
explicit GBDT(const BoostingConfig* config); GBDT();
/*! /*!
* \brief Destructor * \brief Destructor
*/ */
...@@ -31,9 +31,8 @@ public: ...@@ -31,9 +31,8 @@ public:
* \param training_metrics Training metrics * \param training_metrics Training metrics
* \param output_model_filename Filename of output model * \param output_model_filename Filename of output model
*/ */
void Init(const Dataset* train_data, const ObjectiveFunction* object_function, void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics, const std::vector<const Metric*>& training_metrics)
const char* output_model_filename)
override; override;
/*! /*!
* \brief Adding a validation dataset * \brief Adding a validation dataset
...@@ -45,92 +44,128 @@ public: ...@@ -45,92 +44,128 @@ public:
/*! /*!
* \brief one training iteration * \brief one training iteration
*/ */
void Train() override; bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;
/*! \brief Get eval result */
std::vector<std::string> EvalCurrent(bool is_eval_train) const override;
/*! \brief Get prediction result */
const std::vector<const score_t*> PredictCurrent(bool is_predict_train) const override;
/*! /*!
* \brief Predtion for one record without sigmoid transformation * \brief Predtion for one record without sigmoid transformation
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double PredictRaw(const double * feature_values) const override; float PredictRaw(const float* feature_values, int num_used_model) const override;
/*! /*!
* \brief Predtion for one record with sigmoid transformation if enabled * \brief Predtion for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double Predict(const double * feature_values) const override; float Predict(const float* feature_values, int num_used_model) const override;
/*!
* \brief Predtion for multiclass classification
* \param feature_values Feature value on this record
* \return Prediction result, num_class numbers per line
*/
std::vector<float> PredictMulticlass(const float* value, int num_used_model) const override;
/*! /*!
* \brief Predtion for one record with leaf index * \brief Predtion for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
std::vector<int> PredictLeafIndex(const double* value) const override; std::vector<int> PredictLeafIndex(const float* value, int num_used_model) const override;
/*! /*!
* \brief Serialize models by string * \brief Serialize models by string
* \return String output of tranined model * \return String output of tranined model
*/ */
std::string ModelsToString() const override; void SaveModelToFile(bool is_finish, const char* filename) override;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
* \param model_str The string of model
*/ */
void ModelsFromString(const std::string& model_str, int num_used_model) override; void ModelsFromString(const std::string& model_str) override;
/*! /*!
* \brief Get max feature index of this model * \brief Get max feature index of this model
* \return Max feature index of this model * \return Max feature index of this model
*/ */
inline int MaxFeatureIdx() const override { return max_feature_idx_; } inline int MaxFeatureIdx() const override { return max_feature_idx_; }
/*!
* \brief Get index of label column
* \return index of label column
*/
inline int LabelIdx() const override { return label_idx_; }
/*! /*!
* \brief Get number of weak sub-models * \brief Get number of weak sub-models
* \return Number of weak sub-models * \return Number of weak sub-models
*/ */
inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); } inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); }
/*!
* \brief Get number of classes
* \return Number of classes
*/
inline int NumberOfClass() const override { return num_class_; }
/*!
* \brief Get Type name of this boosting object
*/
const char* Name() const override { return "gbdt"; }
private: private:
/*! /*!
* \brief Implement bagging logic * \brief Implement bagging logic
* \param iter Current interation * \param iter Current interation
* \param curr_class Current class for multiclass training
*/ */
void Bagging(int iter); void Bagging(int iter, const int curr_class);
/*! /*!
* \brief updating score for out-of-bag data. * \brief updating score for out-of-bag data.
* Data should be update since we may re-bagging data on training * Data should be update since we may re-bagging data on training
* \param tree Trained tree of this iteration * \param tree Trained tree of this iteration
* \param curr_class Current class for multiclass training
*/ */
void UpdateScoreOutOfBag(const Tree* tree); void UpdateScoreOutOfBag(const Tree* tree, const int curr_class);
/*! /*!
* \brief calculate the object function * \brief calculate the object function
*/ */
void Boosting(); void Boosting();
/*! /*!
* \brief training one tree
* \return Trained tree of this iteration
*/
Tree* TrainOneTree();
/*!
* \brief updating score after tree was trained * \brief updating score after tree was trained
* \param tree Trained tree of this iteration * \param tree Trained tree of this iteration
* \param curr_class Current class for multiclass training
*/ */
void UpdateScore(const Tree* tree); void UpdateScore(const Tree* tree, const int curr_class);
/*! /*!
* \brief Print Metric result of current iteration * \brief Print metric result of current iteration
* \param iter Current interation * \param iter Current interation
*/ */
bool OutputMetric(int iter); bool OutputMetric(int iter);
/*!
int early_stopping_round_; * \brief Calculate feature importances
* \param last_iter Last tree use to calculate
*/
std::string FeatureImportance() const;
/*! \brief current iteration */
int iter_;
/*! \brief Pointer to training data */ /*! \brief Pointer to training data */
const Dataset* train_data_; const Dataset* train_data_;
/*! \brief Config of gbdt */ /*! \brief Config of gbdt */
const GBDTConfig* gbdt_config_; const GBDTConfig* gbdt_config_;
/*! \brief Tree learner, will use tihs class to learn trees */ /*! \brief Tree learner, will use this class to learn trees */
TreeLearner* tree_learner_; std::vector<TreeLearner*> tree_learner_;
/*! \brief Objective function */ /*! \brief Objective function */
const ObjectiveFunction* object_function_; const ObjectiveFunction* object_function_;
/*! \brief Store and update traning data's score */ /*! \brief Store and update training data's score */
ScoreUpdater* train_score_updater_; ScoreUpdater* train_score_updater_;
/*! \brief Metrics for training data */ /*! \brief Metrics for training data */
std::vector<const Metric*> training_metrics_; std::vector<const Metric*> training_metrics_;
...@@ -138,6 +173,8 @@ private: ...@@ -138,6 +173,8 @@ private:
std::vector<ScoreUpdater*> valid_score_updater_; std::vector<ScoreUpdater*> valid_score_updater_;
/*! \brief Metric for validation data */ /*! \brief Metric for validation data */
std::vector<std::vector<const Metric*>> valid_metrics_; std::vector<std::vector<const Metric*>> valid_metrics_;
/*! \brief Number of rounds for early stopping */
int early_stopping_round_;
/*! \brief Best score(s) for early stopping */ /*! \brief Best score(s) for early stopping */
std::vector<std::vector<int>> best_iter_; std::vector<std::vector<int>> best_iter_;
std::vector<std::vector<score_t>> best_score_; std::vector<std::vector<score_t>> best_score_;
...@@ -159,15 +196,21 @@ private: ...@@ -159,15 +196,21 @@ private:
data_size_t bag_data_cnt_; data_size_t bag_data_cnt_;
/*! \brief Number of traning data */ /*! \brief Number of traning data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
/*! \brief Random generator, used for bagging */ /*! \brief Random generator, used for bagging */
Random random_; Random random_;
/*! \brief The filename that the models will save to */
FILE * output_model_file;
/*! /*!
* \brief Sigmoid parameter, used for prediction. * \brief Sigmoid parameter, used for prediction.
* if > 0 meas output score will transform by sigmoid function * if > 0 meas output score will transform by sigmoid function
*/ */
double sigmoid_; float sigmoid_;
/*! \brief Index of label column */
data_size_t label_idx_;
/*! \brief Saved number of models */
int saved_model_size_ = -1;
/*! \brief File to write models */
std::ofstream model_output_file_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -18,13 +18,13 @@ public: ...@@ -18,13 +18,13 @@ public:
* \brief Constructor, will pass a const pointer of dataset * \brief Constructor, will pass a const pointer of dataset
* \param data This class will bind with this data set * \param data This class will bind with this data set
*/ */
explicit ScoreUpdater(const Dataset* data) explicit ScoreUpdater(const Dataset* data, int num_class)
:data_(data) { :data_(data) {
num_data_ = data->num_data(); num_data_ = data->num_data();
score_ = new score_t[num_data_]; score_ = new score_t[num_data_ * num_class];
// default start score is zero // default start score is zero
std::memset(score_, 0, sizeof(score_t)*num_data_); std::memset(score_, 0, sizeof(score_t) * num_data_ * num_class);
const score_t* init_score = data->metadata().init_score(); const float* init_score = data->metadata().init_score();
// if exists initial score, will start from it // if exists initial score, will start from it
if (init_score != nullptr) { if (init_score != nullptr) {
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
...@@ -41,8 +41,8 @@ public: ...@@ -41,8 +41,8 @@ public:
* Note: this function generally will be used on validation data too. * Note: this function generally will be used on validation data too.
* \param tree Trained tree model * \param tree Trained tree model
*/ */
inline void AddScore(const Tree* tree) { inline void AddScore(const Tree* tree, int curr_class) {
tree->AddPredictionToScore(data_, num_data_, score_); tree->AddPredictionToScore(data_, num_data_, score_ + curr_class * num_data_);
} }
/*! /*!
* \brief Adding prediction score, only used for training data. * \brief Adding prediction score, only used for training data.
...@@ -50,19 +50,19 @@ public: ...@@ -50,19 +50,19 @@ public:
* Based on which We can get prediction quckily. * Based on which We can get prediction quckily.
* \param tree_learner * \param tree_learner
*/ */
inline void AddScore(const TreeLearner* tree_learner) { inline void AddScore(const TreeLearner* tree_learner, int curr_class) {
tree_learner->AddPredictionToScore(score_); tree_learner->AddPredictionToScore(score_ + curr_class * num_data_);
} }
/*! /*!
* \brief Using tree model to get prediction number, then adding to scores for parts of data * \brief Using tree model to get prediction number, then adding to scores for parts of data
* Used for prediction of training out-of-bag data * Used for prediction of training out-of-bag data
* \param tree Trained tree model * \param tree Trained tree model
* \param data_indices Indices of data that want proccess to * \param data_indices Indices of data that will be proccessed
* \param data_cnt Number of data that want proccess to * \param data_cnt Number of data that will be proccessed
*/ */
inline void AddScore(const Tree* tree, const data_size_t* data_indices, inline void AddScore(const Tree* tree, const data_size_t* data_indices,
data_size_t data_cnt) { data_size_t data_cnt, int curr_class) {
tree->AddPredictionToScore(data_, data_indices, data_cnt, score_); tree->AddPredictionToScore(data_, data_indices, data_cnt, score_ + curr_class * num_data_);
} }
/*! \brief Pointer of score */ /*! \brief Pointer of score */
inline const score_t * score() { return score_; } inline const score_t * score() { return score_; }
...@@ -72,7 +72,7 @@ private: ...@@ -72,7 +72,7 @@ private:
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of data set */ /*! \brief Pointer of data set */
const Dataset* data_; const Dataset* data_;
/*! \brief scores for data set */ /*! \brief Scores for data set */
score_t* score_; score_t* score_;
}; };
......
...@@ -23,7 +23,7 @@ BinMapper::BinMapper(const BinMapper& other) ...@@ -23,7 +23,7 @@ BinMapper::BinMapper(const BinMapper& other)
num_bin_ = other.num_bin_; num_bin_ = other.num_bin_;
is_trival_ = other.is_trival_; is_trival_ = other.is_trival_;
sparse_rate_ = other.sparse_rate_; sparse_rate_ = other.sparse_rate_;
bin_upper_bound_ = new double[num_bin_]; bin_upper_bound_ = new float[num_bin_];
for (int i = 0; i < num_bin_; ++i) { for (int i = 0; i < num_bin_; ++i) {
bin_upper_bound_[i] = other.bin_upper_bound_[i]; bin_upper_bound_[i] = other.bin_upper_bound_[i];
} }
...@@ -38,10 +38,10 @@ BinMapper::~BinMapper() { ...@@ -38,10 +38,10 @@ BinMapper::~BinMapper() {
delete[] bin_upper_bound_; delete[] bin_upper_bound_;
} }
void BinMapper::FindBin(std::vector<double>* values, int max_bin) { void BinMapper::FindBin(std::vector<float>* values, int max_bin) {
size_t sample_size = values->size(); size_t sample_size = values->size();
// find distinct_values first // find distinct_values first
double* distinct_values = new double[sample_size]; float* distinct_values = new float[sample_size];
int *counts = new int[sample_size]; int *counts = new int[sample_size];
int num_values = 1; int num_values = 1;
std::sort(values->begin(), values->end()); std::sort(values->begin(), values->end());
...@@ -61,19 +61,19 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) { ...@@ -61,19 +61,19 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) {
if (num_values <= max_bin) { if (num_values <= max_bin) {
// use distinct value is enough // use distinct value is enough
num_bin_ = num_values; num_bin_ = num_values;
bin_upper_bound_ = new double[num_values]; bin_upper_bound_ = new float[num_values];
for (int i = 0; i < num_values - 1; ++i) { for (int i = 0; i < num_values - 1; ++i) {
bin_upper_bound_[i] = (distinct_values[i] + distinct_values[i + 1]) / 2; bin_upper_bound_[i] = (distinct_values[i] + distinct_values[i + 1]) / 2;
} }
cnt_in_bin0 = counts[0]; cnt_in_bin0 = counts[0];
bin_upper_bound_[num_values - 1] = std::numeric_limits<double>::infinity(); bin_upper_bound_[num_values - 1] = std::numeric_limits<float>::infinity();
} else { } else {
// need find bins // need find bins
num_bin_ = max_bin; num_bin_ = max_bin;
bin_upper_bound_ = new double[max_bin]; bin_upper_bound_ = new float[max_bin];
double * bin_lower_bound = new double[max_bin]; float * bin_lower_bound = new float[max_bin];
// mean size for one bin // mean size for one bin
double mean_bin_size = sample_size / static_cast<double>(max_bin); float mean_bin_size = sample_size / static_cast<float>(max_bin);
int rest_sample_cnt = static_cast<int>(sample_size); int rest_sample_cnt = static_cast<int>(sample_size);
int cur_cnt_inbin = 0; int cur_cnt_inbin = 0;
int bin_cnt = 0; int bin_cnt = 0;
...@@ -88,24 +88,24 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) { ...@@ -88,24 +88,24 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) {
++bin_cnt; ++bin_cnt;
bin_lower_bound[bin_cnt] = distinct_values[i + 1]; bin_lower_bound[bin_cnt] = distinct_values[i + 1];
cur_cnt_inbin = 0; cur_cnt_inbin = 0;
mean_bin_size = rest_sample_cnt / static_cast<double>(max_bin - bin_cnt); mean_bin_size = rest_sample_cnt / static_cast<float>(max_bin - bin_cnt);
} }
} }
cur_cnt_inbin += counts[num_values - 1]; cur_cnt_inbin += counts[num_values - 1];
// update bin upper bound // update bin upper bound
for (int i = 0; i < bin_cnt; ++i) { for (int i = 0; i < bin_cnt; ++i) {
bin_upper_bound_[i] = (bin_upper_bound_[i] + bin_lower_bound[i + 1]) / 2.0; bin_upper_bound_[i] = (bin_upper_bound_[i] + bin_lower_bound[i + 1]) / 2.0f;
} }
// last bin upper bound // last bin upper bound
bin_upper_bound_[bin_cnt] = std::numeric_limits<double>::infinity(); bin_upper_bound_[bin_cnt] = std::numeric_limits<float>::infinity();
++bin_cnt; ++bin_cnt;
delete[] bin_lower_bound; delete[] bin_lower_bound;
// if no so much bin // if no so much bin
if (bin_cnt < max_bin) { if (bin_cnt < max_bin) {
// old bin data // old bin data
double * tmp_bin_upper_bound = bin_upper_bound_; float* tmp_bin_upper_bound = bin_upper_bound_;
num_bin_ = bin_cnt; num_bin_ = bin_cnt;
bin_upper_bound_ = new double[num_bin_]; bin_upper_bound_ = new float[num_bin_];
// copy back // copy back
for (int i = 0; i < num_bin_; ++i) { for (int i = 0; i < num_bin_; ++i) {
bin_upper_bound_[i] = tmp_bin_upper_bound[i]; bin_upper_bound_[i] = tmp_bin_upper_bound[i];
...@@ -123,7 +123,7 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) { ...@@ -123,7 +123,7 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) {
is_trival_ = false; is_trival_ = false;
} }
// calculate sparse rate // calculate sparse rate
sparse_rate_ = static_cast<double>(cnt_in_bin0) / static_cast<double>(sample_size); sparse_rate_ = static_cast<float>(cnt_in_bin0) / static_cast<float>(sample_size);
} }
...@@ -131,8 +131,8 @@ int BinMapper::SizeForSpecificBin(int bin) { ...@@ -131,8 +131,8 @@ int BinMapper::SizeForSpecificBin(int bin) {
int size = 0; int size = 0;
size += sizeof(int); size += sizeof(int);
size += sizeof(bool); size += sizeof(bool);
size += sizeof(double); size += sizeof(float);
size += bin * sizeof(double); size += bin * sizeof(float);
return size; return size;
} }
...@@ -143,7 +143,7 @@ void BinMapper::CopyTo(char * buffer) { ...@@ -143,7 +143,7 @@ void BinMapper::CopyTo(char * buffer) {
buffer += sizeof(is_trival_); buffer += sizeof(is_trival_);
std::memcpy(buffer, &sparse_rate_, sizeof(sparse_rate_)); std::memcpy(buffer, &sparse_rate_, sizeof(sparse_rate_));
buffer += sizeof(sparse_rate_); buffer += sizeof(sparse_rate_);
std::memcpy(buffer, bin_upper_bound_, num_bin_ * sizeof(double)); std::memcpy(buffer, bin_upper_bound_, num_bin_ * sizeof(float));
} }
void BinMapper::CopyFrom(const char * buffer) { void BinMapper::CopyFrom(const char * buffer) {
...@@ -154,19 +154,19 @@ void BinMapper::CopyFrom(const char * buffer) { ...@@ -154,19 +154,19 @@ void BinMapper::CopyFrom(const char * buffer) {
std::memcpy(&sparse_rate_, buffer, sizeof(sparse_rate_)); std::memcpy(&sparse_rate_, buffer, sizeof(sparse_rate_));
buffer += sizeof(sparse_rate_); buffer += sizeof(sparse_rate_);
if (bin_upper_bound_ != nullptr) { delete[] bin_upper_bound_; } if (bin_upper_bound_ != nullptr) { delete[] bin_upper_bound_; }
bin_upper_bound_ = new double[num_bin_]; bin_upper_bound_ = new float[num_bin_];
std::memcpy(bin_upper_bound_, buffer, num_bin_ * sizeof(double)); std::memcpy(bin_upper_bound_, buffer, num_bin_ * sizeof(float));
} }
void BinMapper::SaveBinaryToFile(FILE* file) const { void BinMapper::SaveBinaryToFile(FILE* file) const {
fwrite(&num_bin_, sizeof(num_bin_), 1, file); fwrite(&num_bin_, sizeof(num_bin_), 1, file);
fwrite(&is_trival_, sizeof(is_trival_), 1, file); fwrite(&is_trival_, sizeof(is_trival_), 1, file);
fwrite(&sparse_rate_, sizeof(sparse_rate_), 1, file); fwrite(&sparse_rate_, sizeof(sparse_rate_), 1, file);
fwrite(bin_upper_bound_, sizeof(double), num_bin_, file); fwrite(bin_upper_bound_, sizeof(float), num_bin_, file);
} }
size_t BinMapper::SizesInByte() const { size_t BinMapper::SizesInByte() const {
return sizeof(num_bin_) + sizeof(is_trival_) + sizeof(sparse_rate_) + sizeof(double) * num_bin_; return sizeof(num_bin_) + sizeof(is_trival_) + sizeof(sparse_rate_) + sizeof(float) * num_bin_;
} }
template class DenseBin<uint8_t>; template class DenseBin<uint8_t>;
...@@ -182,9 +182,9 @@ template class OrderedSparseBin<uint16_t>; ...@@ -182,9 +182,9 @@ template class OrderedSparseBin<uint16_t>;
template class OrderedSparseBin<uint32_t>; template class OrderedSparseBin<uint32_t>;
Bin* Bin::CreateBin(data_size_t num_data, int num_bin, double sparse_rate, bool is_enable_sparse, bool* is_sparse, int default_bin) { Bin* Bin::CreateBin(data_size_t num_data, int num_bin, float sparse_rate, bool is_enable_sparse, bool* is_sparse, int default_bin) {
// sparse threshold // sparse threshold
const double kSparseThreshold = 0.8; const float kSparseThreshold = 0.8f;
if (sparse_rate >= kSparseThreshold && is_enable_sparse) { if (sparse_rate >= kSparseThreshold && is_enable_sparse) {
*is_sparse = true; *is_sparse = true;
return CreateSparseBin(num_data, num_bin, default_bin); return CreateSparseBin(num_data, num_bin, default_bin);
......
...@@ -10,6 +10,26 @@ ...@@ -10,6 +10,26 @@
namespace LightGBM { namespace LightGBM {
void OverallConfig::LoadFromString(const char* str) {
std::unordered_map<std::string, std::string> params;
auto args = Common::Split(str, " \t\n\r");
for (auto arg : args) {
std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '=');
if (tmp_strs.size() == 2) {
std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
std::string value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
if (key.size() <= 0) {
continue;
}
params[key] = value;
} else {
Log::Error("Unknown parameter %s", arg.c_str());
}
}
ParameterAlias::KeyAliasTransform(&params);
Set(params);
}
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) { void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types // load main config types
GetInt(params, "num_threads", &num_threads); GetInt(params, "num_threads", &num_threads);
...@@ -26,7 +46,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -26,7 +46,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
boosting_config = new GBDTConfig(); boosting_config = new GBDTConfig();
} }
// sub-config setup // sub-config setup
network_config.Set(params); network_config.Set(params);
io_config.Set(params); io_config.Set(params);
...@@ -113,6 +132,28 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin ...@@ -113,6 +132,28 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
void OverallConfig::CheckParamConflict() { void OverallConfig::CheckParamConflict() {
GBDTConfig* gbdt_config = dynamic_cast<GBDTConfig*>(boosting_config); GBDTConfig* gbdt_config = dynamic_cast<GBDTConfig*>(boosting_config);
// check if objective_type, metric_type, and num_class match
bool objective_type_multiclass = (objective_type == std::string("multiclass"));
int num_class_check = gbdt_config->num_class;
if (objective_type_multiclass){
if (num_class_check <= 1){
Log::Fatal("You should specify number of class(>=2) for multiclass training.");
}
}
else {
if (task_type == TaskType::kTrain && num_class_check != 1){
Log::Fatal("Number of class must be 1 for non-multiclass training.");
}
}
for (std::string metric_type : metric_types){
bool metric_type_multiclass = ( metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error"));
if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)){
Log::Fatal("Objective and metrics don't match.");
}
}
if (network_config.num_machines > 1) { if (network_config.num_machines > 1) {
is_parallel = true; is_parallel = true;
} else { } else {
...@@ -159,48 +200,52 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -159,48 +200,52 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetString(params, "input_model", &input_model); GetString(params, "input_model", &input_model);
GetString(params, "output_result", &output_result); GetString(params, "output_result", &output_result);
GetString(params, "input_init_score", &input_init_score); GetString(params, "input_init_score", &input_init_score);
GetString(params, "log_file", &log_file);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "valid_data", &tmp_str)) { if (GetString(params, "valid_data", &tmp_str)) {
valid_data_filenames = Common::Split(tmp_str.c_str(), ','); valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
} }
GetBool(params, "has_header", &has_header);
GetString(params, "label_column", &label_column);
GetString(params, "weight_column", &weight_column);
GetString(params, "group_column", &group_column);
GetString(params, "ignore_column", &ignore_column);
} }
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) { void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "is_unbalance", &is_unbalance); GetBool(params, "is_unbalance", &is_unbalance);
GetDouble(params, "sigmoid", &sigmoid); GetFloat(params, "sigmoid", &sigmoid);
GetInt(params, "max_position", &max_position); GetInt(params, "max_position", &max_position);
CHECK(max_position > 0); CHECK(max_position > 0);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToDoubleArray(tmp_str, ','); label_gain = Common::StringToFloatArray(tmp_str, ',');
} else { } else {
// label_gain = 2^i - 1, may overflow, so we use 31 here // label_gain = 2^i - 1, may overflow, so we use 31 here
const int max_label = 31; const int max_label = 31;
label_gain.push_back(0.0); label_gain.push_back(0.0f);
for (int i = 1; i < max_label; ++i) { for (int i = 1; i < max_label; ++i) {
label_gain.push_back((1 << i) - 1); label_gain.push_back(static_cast<float>((1 << i) - 1));
} }
} }
} }
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) { void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "early_stopping_round", &early_stopping_round); GetFloat(params, "sigmoid", &sigmoid);
GetInt(params, "metric_freq", &output_freq); GetInt(params, "num_class", &num_class);
CHECK(output_freq >= 0); CHECK(num_class >= 1);
GetDouble(params, "sigmoid", &sigmoid);
GetBool(params, "is_training_metric", &is_provide_training_metric);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToDoubleArray(tmp_str, ','); label_gain = Common::StringToFloatArray(tmp_str, ',');
} else { } else {
// label_gain = 2^i - 1, may overflow, so we use 31 here // label_gain = 2^i - 1, may overflow, so we use 31 here
const int max_label = 31; const int max_label = 31;
label_gain.push_back(0.0); label_gain.push_back(0.0f);
for (int i = 1; i < max_label; ++i) { for (int i = 1; i < max_label; ++i) {
label_gain.push_back((1 << i) - 1); label_gain.push_back(static_cast<float>((1 << i) - 1));
} }
} }
if (GetString(params, "ndcg_eval_at", &tmp_str)) { if (GetString(params, "ndcg_eval_at", &tmp_str)) {
...@@ -220,14 +265,16 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param ...@@ -220,14 +265,16 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) { void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "min_data_in_leaf", &min_data_in_leaf); GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf); GetFloat(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0); CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0);
GetInt(params, "num_leaves", &num_leaves); GetInt(params, "num_leaves", &num_leaves);
CHECK(num_leaves > 1); CHECK(num_leaves > 1);
GetInt(params, "feature_fraction_seed", &feature_fraction_seed); GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
GetDouble(params, "feature_fraction", &feature_fraction); GetFloat(params, "feature_fraction", &feature_fraction);
CHECK(feature_fraction > 0.0 && feature_fraction <= 1.0); CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
GetDouble(params, "histogram_pool_size", &histogram_pool_size); GetFloat(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth);
CHECK(max_depth > 1 || max_depth < 0);
} }
...@@ -237,12 +284,17 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -237,12 +284,17 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetInt(params, "bagging_seed", &bagging_seed); GetInt(params, "bagging_seed", &bagging_seed);
GetInt(params, "bagging_freq", &bagging_freq); GetInt(params, "bagging_freq", &bagging_freq);
CHECK(bagging_freq >= 0); CHECK(bagging_freq >= 0);
GetDouble(params, "bagging_fraction", &bagging_fraction); GetFloat(params, "bagging_fraction", &bagging_fraction);
CHECK(bagging_fraction > 0.0 && bagging_fraction <= 1.0); CHECK(bagging_fraction > 0.0f && bagging_fraction <= 1.0f);
GetDouble(params, "learning_rate", &learning_rate); GetFloat(params, "learning_rate", &learning_rate);
CHECK(learning_rate > 0.0); CHECK(learning_rate > 0.0f);
GetInt(params, "early_stopping_round", &early_stopping_round); GetInt(params, "early_stopping_round", &early_stopping_round);
CHECK(early_stopping_round >= 0); CHECK(early_stopping_round >= 0);
GetInt(params, "metric_freq", &output_freq);
CHECK(output_freq >= 0);
GetBool(params, "is_training_metric", &is_provide_training_metric);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
} }
void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) { void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
......
...@@ -11,13 +11,14 @@ ...@@ -11,13 +11,14 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <string> #include <string>
#include <sstream>
namespace LightGBM { namespace LightGBM {
Dataset::Dataset(const char* data_filename, const char* init_score_filename, Dataset::Dataset(const char* data_filename, const char* init_score_filename,
int max_bin, int random_seed, bool is_enable_sparse, const PredictFunction& predict_fun) const IOConfig& io_config, const PredictFunction& predict_fun)
:data_filename_(data_filename), random_(random_seed), :data_filename_(data_filename), random_(io_config.data_random_seed),
max_bin_(max_bin), is_enable_sparse_(is_enable_sparse), predict_fun_(predict_fun) { max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse), predict_fun_(predict_fun) {
CheckCanLoadFromBin(); CheckCanLoadFromBin();
if (is_loading_from_binfile_ && predict_fun != nullptr) { if (is_loading_from_binfile_ && predict_fun != nullptr) {
...@@ -28,15 +29,129 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -28,15 +29,129 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
if (!is_loading_from_binfile_) { if (!is_loading_from_binfile_) {
// load weight, query information and initilize score // load weight, query information and initilize score
metadata_.Init(data_filename, init_score_filename); metadata_.Init(data_filename, init_score_filename);
// create text reader
text_reader_ = new TextReader<data_size_t>(data_filename, io_config.has_header);
std::unordered_map<std::string, int> name2idx;
// get column names
if (io_config.has_header) {
std::string first_line = text_reader_->first_line();
feature_names_ = Common::Split(first_line.c_str(), "\t ,");
for (size_t i = 0; i < feature_names_.size(); ++i) {
name2idx[feature_names_[i]] = static_cast<int>(i);
}
}
std::string name_prefix("name:");
// load label idx
if (io_config.label_column.size() > 0) {
if (Common::StartsWith(io_config.label_column, name_prefix)) {
std::string name = io_config.label_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) {
label_idx_ = name2idx[name];
Log::Info("use %s column as label", name.c_str());
} else {
Log::Fatal("cannot find label column: %s in data file", name.c_str());
}
} else {
if (!Common::AtoiAndCheck(io_config.label_column.c_str(), &label_idx_)) {
Log::Fatal("label_column is not a number, \
if you want to use column name, \
please add prefix \"name:\" before column name");
}
Log::Info("use %d-th column as label", label_idx_);
}
}
if (feature_names_.size() > 0) {
// erase label column name
feature_names_.erase(feature_names_.begin() + label_idx_);
}
// load ignore columns
if (io_config.ignore_column.size() > 0) {
if (Common::StartsWith(io_config.ignore_column, name_prefix)) {
std::string names = io_config.ignore_column.substr(name_prefix.size());
for (auto name : Common::Split(names.c_str(), ',')) {
if (name2idx.count(name) > 0) {
int tmp = name2idx[name];
// skip for label column
if (tmp > label_idx_) { tmp -= 1; }
ignore_features_.emplace(tmp);
} else {
Log::Fatal("cannot find column: %s in data file", name.c_str());
}
}
} else {
for (auto token : Common::Split(io_config.ignore_column.c_str(), ',')) {
int tmp = 0;
if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Log::Fatal("ignore_column is not a number, \
if you want to use column name, \
please add prefix \"name:\" before column name");
}
// skip for label column
if (tmp > label_idx_) { tmp -= 1; }
ignore_features_.emplace(tmp);
}
}
}
// load weight idx
if (io_config.weight_column.size() > 0) {
if (Common::StartsWith(io_config.weight_column, name_prefix)) {
std::string name = io_config.weight_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) {
weight_idx_ = name2idx[name];
Log::Info("use %s column as weight", name.c_str());
} else {
Log::Fatal("cannot find weight column: %s in data file", name.c_str());
}
} else {
if (!Common::AtoiAndCheck(io_config.weight_column.c_str(), &weight_idx_)) {
Log::Fatal("weight_column is not a number, \
if you want to use column name, \
please add prefix \"name:\" before column name");
}
Log::Info("use %d-th column as weight", weight_idx_);
}
// skip for label column
if (weight_idx_ > label_idx_) {
weight_idx_ -= 1;
}
ignore_features_.emplace(weight_idx_);
}
if (io_config.group_column.size() > 0) {
if (Common::StartsWith(io_config.group_column, name_prefix)) {
std::string name = io_config.group_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) {
group_idx_ = name2idx[name];
Log::Info("use %s column as group/query id", name.c_str());
} else {
Log::Fatal("cannot find group/query column: %s in data file", name.c_str());
}
} else {
if (!Common::AtoiAndCheck(io_config.group_column.c_str(), &group_idx_)) {
Log::Fatal("group_column is not a number, \
if you want to use column name, \
please add prefix \"name:\" before column name");
}
Log::Info("use %d-th column as group/query id", group_idx_);
}
// skip for label column
if (group_idx_ > label_idx_) {
group_idx_ -= 1;
}
ignore_features_.emplace(group_idx_);
}
// create text parser // create text parser
parser_ = Parser::CreateParser(data_filename_, 0, nullptr); parser_ = Parser::CreateParser(data_filename_, io_config.has_header, 0, label_idx_);
if (parser_ == nullptr) { if (parser_ == nullptr) {
Log::Fatal("Cannot recognising input data format, filename: %s", data_filename_); Log::Fatal("Cannot recognising input data format, filename: %s", data_filename_);
} }
// create text reader
text_reader_ = new TextReader<data_size_t>(data_filename);
} else { } else {
// only need to load initilize score, other meta data will load from bin flie // only need to load initilize score, other meta data will be loaded from bin flie
metadata_.Init(init_score_filename); metadata_.Init(init_score_filename);
Log::Info("Loading data set from binary file"); Log::Info("Loading data set from binary file");
parser_ = nullptr; parser_ = nullptr;
...@@ -159,10 +274,10 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti ...@@ -159,10 +274,10 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti
void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<std::string>& sample_data) { void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<std::string>& sample_data) {
// sample_values[i][j], means the value of j-th sample on i-th feature // sample_values[i][j], means the value of j-th sample on i-th feature
std::vector<std::vector<double>> sample_values; std::vector<std::vector<float>> sample_values;
// temp buffer for one line features and label // temp buffer for one line features and label
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, float>> oneline_features;
double label; float label;
for (size_t i = 0; i < sample_data.size(); ++i) { for (size_t i = 0; i < sample_data.size(); ++i) {
oneline_features.clear(); oneline_features.clear();
// parse features // parse features
...@@ -171,13 +286,13 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -171,13 +286,13 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
for (auto& feature_values : sample_values) { for (auto& feature_values : sample_values) {
feature_values.push_back(0.0); feature_values.push_back(0.0);
} }
for (std::pair<int, double>& inner_data : oneline_features) { for (std::pair<int, float>& inner_data : oneline_features) {
if (static_cast<size_t>(inner_data.first) >= sample_values.size()) { if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
// if need expand feature set // if need expand feature set
size_t need_size = inner_data.first - sample_values.size() + 1; size_t need_size = inner_data.first - sample_values.size() + 1;
for (size_t j = 0; j < need_size; ++j) { for (size_t j = 0; j < need_size; ++j) {
// push i+1 0 // push i+1 0
sample_values.emplace_back(i + 1, 0.0); sample_values.emplace_back(i + 1, 0.0f);
} }
} }
// edit the feature value // edit the feature value
...@@ -190,18 +305,40 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -190,18 +305,40 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// -1 means doesn't use this feature // -1 means doesn't use this feature
used_feature_map_ = std::vector<int>(sample_values.size(), -1); used_feature_map_ = std::vector<int>(sample_values.size(), -1);
num_total_features_ = static_cast<int>(sample_values.size()); num_total_features_ = static_cast<int>(sample_values.size());
// check the range of label_idx, weight_idx and group_idx
CHECK(label_idx_ >= 0 && label_idx_ <= num_total_features_);
CHECK(weight_idx_ < 0 || weight_idx_ < num_total_features_);
CHECK(group_idx_ < 0 || group_idx_ < num_total_features_);
// fill feature_names_ if not header
if (feature_names_.size() <= 0) {
for (int i = 0; i < num_total_features_; ++i) {
std::stringstream str_buf;
str_buf << "Column_" << i;
feature_names_.push_back(str_buf.str());
}
}
// start find bins // start find bins
if (num_machines == 1) { if (num_machines == 1) {
std::vector<BinMapper*> bin_mappers(sample_values.size()); std::vector<BinMapper*> bin_mappers(sample_values.size());
// if only 1 machines, find bin locally // if only 1 machines, find bin locally
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) { for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
if (ignore_features_.count(i) > 0) {
bin_mappers[i] = nullptr;
continue;
}
bin_mappers[i] = new BinMapper(); bin_mappers[i] = new BinMapper();
bin_mappers[i]->FindBin(&sample_values[i], max_bin_); bin_mappers[i]->FindBin(&sample_values[i], max_bin_);
} }
for (size_t i = 0; i < sample_values.size(); ++i) { for (size_t i = 0; i < sample_values.size(); ++i) {
if (!bin_mappers[i]->is_trival()) { if (bin_mappers[i] == nullptr) {
Log::Error("Ignore Feature %s ", feature_names_[i].c_str());
}
else if (!bin_mappers[i]->is_trival()) {
// map real feature index to used feature index // map real feature index to used feature index
used_feature_map_[i] = static_cast<int>(features_.size()); used_feature_map_[i] = static_cast<int>(features_.size());
// push new feature // push new feature
...@@ -209,7 +346,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -209,7 +346,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
num_data_, is_enable_sparse_)); num_data_, is_enable_sparse_));
} else { } else {
// if feature is trival(only 1 bin), free spaces // if feature is trival(only 1 bin), free spaces
Log::Error("Feature %d only contains one value, will be ignored", i); Log::Error("Feature %s only contains one value, will be ignored", feature_names_[i].c_str());
delete bin_mappers[i]; delete bin_mappers[i];
} }
} }
...@@ -256,12 +393,17 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -256,12 +393,17 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
Network::Allgather(input_buffer, buffer_size, start, len, output_buffer); Network::Allgather(input_buffer, buffer_size, start, len, output_buffer);
// restore features bins from buffer // restore features bins from buffer
for (int i = 0; i < total_num_feature; ++i) { for (int i = 0; i < total_num_feature; ++i) {
if (ignore_features_.count(i) > 0) {
Log::Error("Ignore Feature %s ", feature_names_[i].c_str());
continue;
}
BinMapper* bin_mapper = new BinMapper(); BinMapper* bin_mapper = new BinMapper();
bin_mapper->CopyFrom(output_buffer + i * type_size); bin_mapper->CopyFrom(output_buffer + i * type_size);
if (!bin_mapper->is_trival()) { if (!bin_mapper->is_trival()) {
used_feature_map_[i] = static_cast<int>(features_.size()); used_feature_map_[i] = static_cast<int>(features_.size());
features_.push_back(new Feature(static_cast<int>(i), bin_mapper, num_data_, is_enable_sparse_)); features_.push_back(new Feature(static_cast<int>(i), bin_mapper, num_data_, is_enable_sparse_));
} else { } else {
Log::Error("Feature %s only contains one value, will be ignored", feature_names_[i].c_str());
delete bin_mapper; delete bin_mapper;
} }
} }
...@@ -276,6 +418,13 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -276,6 +418,13 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, bool use_two_round_loading) { void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, bool use_two_round_loading) {
// don't support query id in data file when training in parallel
if (num_machines > 1 && !is_pre_partition) {
if (group_idx_ > 0) {
Log::Fatal("Don't support query id in data file when training parallel without pre-partition. \
Please use an additional query file or pre-partition your data");
}
}
used_data_indices_.clear(); used_data_indices_.clear();
if (!is_loading_from_binfile_ ) { if (!is_loading_from_binfile_ ) {
if (!use_two_round_loading) { if (!use_two_round_loading) {
...@@ -287,7 +436,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b ...@@ -287,7 +436,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b
// construct feature bin mappers // construct feature bin mappers
ConstructBinMappers(rank, num_machines, sample_data); ConstructBinMappers(rank, num_machines, sample_data);
// initialize label // initialize label
metadata_.InitLabel(num_data_); metadata_.Init(num_data_, weight_idx_, group_idx_);
// extract features // extract features
ExtractFeaturesFromMemory(); ExtractFeaturesFromMemory();
} else { } else {
...@@ -297,7 +446,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b ...@@ -297,7 +446,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b
// construct feature bin mappers // construct feature bin mappers
ConstructBinMappers(rank, num_machines, sample_data); ConstructBinMappers(rank, num_machines, sample_data);
// initialize label // initialize label
metadata_.InitLabel(num_data_); metadata_.Init(num_data_, weight_idx_, group_idx_);
// extract features // extract features
ExtractFeaturesFromFile(); ExtractFeaturesFromFile();
...@@ -322,7 +471,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo ...@@ -322,7 +471,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
// read data in memory // read data in memory
LoadDataToMemory(0, 1, false); LoadDataToMemory(0, 1, false);
// initialize label // initialize label
metadata_.InitLabel(num_data_); metadata_.Init(num_data_, weight_idx_, group_idx_);
features_.clear(); features_.clear();
// copy feature bin mapper data // copy feature bin mapper data
for (Feature* feature : train_set->features_) { for (Feature* feature : train_set->features_) {
...@@ -336,7 +485,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo ...@@ -336,7 +485,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
// Get number of lines of data file // Get number of lines of data file
num_data_ = static_cast<data_size_t>(text_reader_->CountLine()); num_data_ = static_cast<data_size_t>(text_reader_->CountLine());
// initialize label // initialize label
metadata_.InitLabel(num_data_); metadata_.Init(num_data_, weight_idx_, group_idx_);
features_.clear(); features_.clear();
// copy feature bin mapper data // copy feature bin mapper data
for (Feature* feature : train_set->features_) { for (Feature* feature : train_set->features_) {
...@@ -358,8 +507,8 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo ...@@ -358,8 +507,8 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
} }
void Dataset::ExtractFeaturesFromMemory() { void Dataset::ExtractFeaturesFromMemory() {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, float>> oneline_features;
double tmp_label = 0.0; float tmp_label = 0.0f;
if (predict_fun_ == nullptr) { if (predict_fun_ == nullptr) {
// if doesn't need to prediction with initial model // if doesn't need to prediction with initial model
#pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label)
...@@ -381,11 +530,18 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -381,11 +530,18 @@ void Dataset::ExtractFeaturesFromMemory() {
// if is used feature // if is used feature
features_[feature_idx]->PushData(tid, i, inner_data.second); features_[feature_idx]->PushData(tid, i, inner_data.second);
} }
else {
if (inner_data.first == weight_idx_) {
metadata_.SetWeightAt(i, inner_data.second);
} else if (inner_data.first == group_idx_) {
metadata_.SetQueryAt(i, inner_data.second);
}
}
} }
} }
} else { } else {
// if need to prediction with initial model // if need to prediction with initial model
score_t* init_score = new score_t[num_data_]; float* init_score = new float[num_data_];
#pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
...@@ -393,7 +549,7 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -393,7 +549,7 @@ void Dataset::ExtractFeaturesFromMemory() {
// parser // parser
parser_->ParseOneLine(text_reader_->Lines()[i].c_str(), &oneline_features, &tmp_label); parser_->ParseOneLine(text_reader_->Lines()[i].c_str(), &oneline_features, &tmp_label);
// set initial score // set initial score
init_score[i] = static_cast<score_t>(predict_fun_(oneline_features)); init_score[i] = static_cast<float>(predict_fun_(oneline_features));
// set label // set label
metadata_.SetLabelAt(i, tmp_label); metadata_.SetLabelAt(i, tmp_label);
// free processed line: // free processed line:
...@@ -407,14 +563,22 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -407,14 +563,22 @@ void Dataset::ExtractFeaturesFromMemory() {
// if is used feature // if is used feature
features_[feature_idx]->PushData(tid, i, inner_data.second); features_[feature_idx]->PushData(tid, i, inner_data.second);
} }
else {
if (inner_data.first == weight_idx_) {
metadata_.SetWeightAt(i, inner_data.second);
} else if (inner_data.first == group_idx_) {
metadata_.SetQueryAt(i, inner_data.second);
}
}
} }
} }
// metadata_ will manage space of init_score // metadata_ will manage space of init_score
metadata_.SetInitScore(init_score); metadata_.SetInitScore(init_score, num_data_);
delete[] init_score;
} }
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; i++) { for (int i = 0; i < num_features_; ++i) {
features_[i]->FinishLoad(); features_[i]->FinishLoad();
} }
// text data can be free after loaded feature values // text data can be free after loaded feature values
...@@ -423,24 +587,24 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -423,24 +587,24 @@ void Dataset::ExtractFeaturesFromMemory() {
void Dataset::ExtractFeaturesFromFile() { void Dataset::ExtractFeaturesFromFile() {
score_t* init_score = nullptr; float* init_score = nullptr;
if (predict_fun_ != nullptr) { if (predict_fun_ != nullptr) {
init_score = new score_t[num_data_]; init_score = new float[num_data_];
} }
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
[this, &init_score] [this, &init_score]
(data_size_t start_idx, const std::vector<std::string>& lines) { (data_size_t start_idx, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, float>> oneline_features;
double tmp_label = 0.0; float tmp_label = 0.0f;
#pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); i++) { for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
oneline_features.clear(); oneline_features.clear();
// parser // parser
parser_->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label); parser_->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
// set initial score // set initial score
if (init_score != nullptr) { if (init_score != nullptr) {
init_score[start_idx + i] = static_cast<score_t>(predict_fun_(oneline_features)); init_score[start_idx + i] = static_cast<float>(predict_fun_(oneline_features));
} }
// set label // set label
metadata_.SetLabelAt(start_idx + i, tmp_label); metadata_.SetLabelAt(start_idx + i, tmp_label);
...@@ -451,6 +615,13 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -451,6 +615,13 @@ void Dataset::ExtractFeaturesFromFile() {
// if is used feature // if is used feature
features_[feature_idx]->PushData(tid, start_idx + i, inner_data.second); features_[feature_idx]->PushData(tid, start_idx + i, inner_data.second);
} }
else {
if (inner_data.first == weight_idx_) {
metadata_.SetWeightAt(start_idx + i, inner_data.second);
} else if (inner_data.first == group_idx_) {
metadata_.SetQueryAt(start_idx + i, inner_data.second);
}
}
} }
} }
}; };
...@@ -465,11 +636,12 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -465,11 +636,12 @@ void Dataset::ExtractFeaturesFromFile() {
// metadata_ will manage space of init_score // metadata_ will manage space of init_score
if (init_score != nullptr) { if (init_score != nullptr) {
metadata_.SetInitScore(init_score); metadata_.SetInitScore(init_score, num_data_);
delete[] init_score;
} }
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; i++) { for (int i = 0; i < num_features_; ++i) {
features_[i]->FinishLoad(); features_[i]->FinishLoad();
} }
} }
...@@ -613,7 +785,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -613,7 +785,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer)); size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer));
// re-allocmate space if not enough // re-allocate space if not enough
if (size_of_metadata > buffer_size) { if (size_of_metadata > buffer_size) {
delete[] buffer; delete[] buffer;
buffer_size = size_of_metadata; buffer_size = size_of_metadata;
...@@ -635,7 +807,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -635,7 +807,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
const data_size_t* query_boundaries = metadata_.query_boundaries(); const data_size_t* query_boundaries = metadata_.query_boundaries();
if (query_boundaries == nullptr) { if (query_boundaries == nullptr) {
// if not contain query file, minimal sample unit is one record // if not contain query file, minimal sample unit is one record
for (data_size_t i = 0; i < num_data_; i++) { for (data_size_t i = 0; i < num_data_; ++i) {
if (random_.NextInt(0, num_machines) == rank) { if (random_.NextInt(0, num_machines) == rank) {
used_data_indices_.push_back(i); used_data_indices_.push_back(i);
} }
...@@ -645,7 +817,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -645,7 +817,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
data_size_t num_queries = metadata_.num_queries(); data_size_t num_queries = metadata_.num_queries();
data_size_t qid = -1; data_size_t qid = -1;
bool is_query_used = false; bool is_query_used = false;
for (data_size_t i = 0; i < num_data_; i++) { for (data_size_t i = 0; i < num_data_; ++i) {
if (qid >= num_queries) { if (qid >= num_queries) {
Log::Fatal("current query is exceed the range of query file, please ensure your query file is correct"); Log::Fatal("current query is exceed the range of query file, please ensure your query file is correct");
} }
...@@ -673,7 +845,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -673,7 +845,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
Log::Fatal("Binary file format error at feature %d's size", i); Log::Fatal("Binary file format error at feature %d's size", i);
} }
size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer)); size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer));
// re-allocmate space if not enough // re-allocate space if not enough
if (size_of_feature > buffer_size) { if (size_of_feature > buffer_size) {
delete[] buffer; delete[] buffer;
buffer_size = size_of_feature; buffer_size = size_of_feature;
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace LightGBM { namespace LightGBM {
/*! /*!
* \brief Used to Store bins for dense feature * \brief Used to store bins for dense feature
* Use template to reduce memory cost * Use template to reduce memory cost
*/ */
template <typename VAL_T> template <typename VAL_T>
......
...@@ -10,7 +10,7 @@ namespace LightGBM { ...@@ -10,7 +10,7 @@ namespace LightGBM {
Metadata::Metadata() Metadata::Metadata()
:label_(nullptr), label_int_(nullptr), weights_(nullptr), :label_(nullptr), label_int_(nullptr), weights_(nullptr),
query_boundaries_(nullptr), query_boundaries_(nullptr),
query_weights_(nullptr), init_score_(nullptr) { query_weights_(nullptr), init_score_(nullptr), queries_(nullptr){
} }
...@@ -36,12 +36,31 @@ Metadata::~Metadata() { ...@@ -36,12 +36,31 @@ Metadata::~Metadata() {
if (query_boundaries_ != nullptr) { delete[] query_boundaries_; } if (query_boundaries_ != nullptr) { delete[] query_boundaries_; }
if (query_weights_ != nullptr) { delete[] query_weights_; } if (query_weights_ != nullptr) { delete[] query_weights_; }
if (init_score_ != nullptr) { delete[] init_score_; } if (init_score_ != nullptr) { delete[] init_score_; }
if (queries_ != nullptr) { delete[] queries_; }
} }
void Metadata::InitLabel(data_size_t num_data) { void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
num_data_ = num_data; num_data_ = num_data;
label_ = new float[num_data_]; label_ = new float[num_data_];
if (weight_idx >= 0) {
if (weights_ != nullptr) {
Log::Info("using weight in data file, and ignore additional weight file");
delete[] weights_;
}
weights_ = new float[num_data_];
num_weights_ = num_data_;
memset(weights_, 0, sizeof(float) * num_data_);
}
if (query_idx >= 0) {
if (query_boundaries_ != nullptr) {
Log::Info("using query id in data file, and ignore additional query file");
delete[] query_boundaries_;
}
if (query_weights_ != nullptr) { delete[] query_weights_; }
queries_ = new data_size_t[num_data_];
memset(queries_, 0, sizeof(data_size_t) * num_data_);
}
} }
void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) { void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
...@@ -59,9 +78,35 @@ void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) { ...@@ -59,9 +78,35 @@ void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) { void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
if (used_data_indices.size() == 0) { if (used_data_indices.size() == 0) {
if (queries_ != nullptr) {
// need convert query_id to boundaries
std::vector<data_size_t> tmp_buffer;
data_size_t last_qid = -1;
data_size_t cur_cnt = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (last_qid != queries_[i]) {
if (cur_cnt > 0) {
tmp_buffer.push_back(cur_cnt);
}
cur_cnt = 0;
last_qid = queries_[i];
}
++cur_cnt;
}
tmp_buffer.push_back(cur_cnt);
query_boundaries_ = new data_size_t[tmp_buffer.size() + 1];
num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
query_boundaries_[0] = 0;
for (size_t i = 0; i < tmp_buffer.size(); ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
}
LoadQueryWeights();
delete[] queries_;
queries_ = nullptr;
}
// check weights // check weights
if (weights_ != nullptr && num_weights_ != num_data_) { if (weights_ != nullptr && num_weights_ != num_data_) {
Log::Error("Initial weight size doesn't equal to data, weights will be ignored"); Log::Fatal("Initial weight size doesn't equal to data");
delete[] weights_; delete[] weights_;
num_weights_ = 0; num_weights_ = 0;
weights_ = nullptr; weights_ = nullptr;
...@@ -69,7 +114,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -69,7 +114,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// check query boundries // check query boundries
if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_data_) { if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_data_) {
Log::Error("Initial query size doesn't equal to data, queies will be ignored"); Log::Fatal("Initial query size doesn't equal to data");
delete[] query_boundaries_; delete[] query_boundaries_;
num_queries_ = 0; num_queries_ = 0;
query_boundaries_ = nullptr; query_boundaries_ = nullptr;
...@@ -78,21 +123,22 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -78,21 +123,22 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// contain initial score file // contain initial score file
if (init_score_ != nullptr && num_init_score_ != num_data_) { if (init_score_ != nullptr && num_init_score_ != num_data_) {
delete[] init_score_; delete[] init_score_;
Log::Error("Initial score size doesn't equal to data, score file will be ignored"); Log::Fatal("Initial score size doesn't equal to data");
init_score_ = nullptr;
num_init_score_ = 0; num_init_score_ = 0;
} }
} else { } else {
data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size()); data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
// check weights // check weights
if (weights_ != nullptr && num_weights_ != num_all_data) { if (weights_ != nullptr && num_weights_ != num_all_data) {
Log::Error("Initial weights size doesn't equal to data, weights will be ignored"); Log::Fatal("Initial weights size doesn't equal to data");
delete[] weights_; delete[] weights_;
num_weights_ = 0; num_weights_ = 0;
weights_ = nullptr; weights_ = nullptr;
} }
// check query boundries // check query boundries
if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_all_data) { if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_all_data) {
Log::Error("Initial query size doesn't equal to data , queries will be ignored"); Log::Fatal("Initial query size doesn't equal to data");
delete[] query_boundaries_; delete[] query_boundaries_;
num_queries_ = 0; num_queries_ = 0;
query_boundaries_ = nullptr; query_boundaries_ = nullptr;
...@@ -100,9 +146,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -100,9 +146,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// contain initial score file // contain initial score file
if (init_score_ != nullptr && num_init_score_ != num_all_data) { if (init_score_ != nullptr && num_init_score_ != num_all_data) {
Log::Error("Initial score size doesn't equal to data , initial scores will be ignored"); Log::Fatal("Initial score size doesn't equal to data");
delete[] init_score_; delete[] init_score_;
num_init_score_ = 0; num_init_score_ = 0;
init_score_ = nullptr;
} }
// get local weights // get local weights
...@@ -131,10 +178,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -131,10 +178,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
used_query.push_back(qid); used_query.push_back(qid);
data_idx += len; data_idx += len;
} else { } else {
Log::Fatal("Data partition error, data didn't match queies"); Log::Fatal("Data partition error, data didn't match queries");
} }
} else { } else {
Log::Fatal("Data partition error, data didn't match queies"); Log::Fatal("Data partition error, data didn't match queries");
} }
} }
data_size_t * old_query_boundaries = query_boundaries_; data_size_t * old_query_boundaries = query_boundaries_;
...@@ -151,9 +198,9 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -151,9 +198,9 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// get local initial scores // get local initial scores
if (init_score_ != nullptr) { if (init_score_ != nullptr) {
score_t* old_scores = init_score_; float* old_scores = init_score_;
num_init_score_ = num_data_; num_init_score_ = num_data_;
init_score_ = new score_t[num_init_score_]; init_score_ = new float[num_init_score_];
for (size_t i = 0; i < used_data_indices.size(); ++i) { for (size_t i = 0; i < used_data_indices.size(); ++i) {
init_score_[i] = old_scores[used_data_indices[i]]; init_score_[i] = old_scores[used_data_indices[i]];
} }
...@@ -166,10 +213,16 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -166,10 +213,16 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
} }
void Metadata::SetInitScore(score_t* init_score) { void Metadata::SetInitScore(const float* init_score, data_size_t len) {
if (num_data_ != len) {
Log::Fatal("len of initial score is not same with #data");
}
if (init_score_ != nullptr) { delete[] init_score_; } if (init_score_ != nullptr) { delete[] init_score_; }
num_init_score_ = num_data_; num_init_score_ = num_data_;
init_score_ = init_score; init_score_ = new float[num_init_score_];
for (data_size_t i = 0; i < num_init_score_; ++i) {
init_score_[i] = init_score[i];
}
} }
void Metadata::LoadWeights() { void Metadata::LoadWeights() {
...@@ -177,7 +230,7 @@ void Metadata::LoadWeights() { ...@@ -177,7 +230,7 @@ void Metadata::LoadWeights() {
std::string weight_filename(data_filename_); std::string weight_filename(data_filename_);
// default weight file name // default weight file name
weight_filename.append(".weight"); weight_filename.append(".weight");
TextReader<size_t> reader(weight_filename.c_str()); TextReader<size_t> reader(weight_filename.c_str(), false);
reader.ReadAllLines(); reader.ReadAllLines();
if (reader.Lines().size() <= 0) { if (reader.Lines().size() <= 0) {
return; return;
...@@ -186,25 +239,25 @@ void Metadata::LoadWeights() { ...@@ -186,25 +239,25 @@ void Metadata::LoadWeights() {
num_weights_ = static_cast<data_size_t>(reader.Lines().size()); num_weights_ = static_cast<data_size_t>(reader.Lines().size());
weights_ = new float[num_weights_]; weights_ = new float[num_weights_];
for (data_size_t i = 0; i < num_weights_; ++i) { for (data_size_t i = 0; i < num_weights_; ++i) {
double tmp_weight = 0.0f; float tmp_weight = 0.0f;
Common::Atof(reader.Lines()[i].c_str(), &tmp_weight); Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
weights_[i] = static_cast<float>(tmp_weight); weights_[i] = tmp_weight;
} }
} }
void Metadata::LoadInitialScore() { void Metadata::LoadInitialScore() {
num_init_score_ = 0; num_init_score_ = 0;
if (init_score_filename_[0] == '\0') { return; } if (init_score_filename_[0] == '\0') { return; }
TextReader<size_t> reader(init_score_filename_); TextReader<size_t> reader(init_score_filename_, false);
reader.ReadAllLines(); reader.ReadAllLines();
Log::Info("Start loading initial scores"); Log::Info("Start loading initial scores");
num_init_score_ = static_cast<data_size_t>(reader.Lines().size()); num_init_score_ = static_cast<data_size_t>(reader.Lines().size());
init_score_ = new score_t[num_init_score_]; init_score_ = new float[num_init_score_];
double tmp = 0.0f; float tmp = 0.0f;
for (data_size_t i = 0; i < num_init_score_; ++i) { for (data_size_t i = 0; i < num_init_score_; ++i) {
Common::Atof(reader.Lines()[i].c_str(), &tmp); Common::Atof(reader.Lines()[i].c_str(), &tmp);
init_score_[i] = static_cast<score_t>(tmp); init_score_[i] = tmp;
} }
} }
...@@ -213,7 +266,7 @@ void Metadata::LoadQueryBoundaries() { ...@@ -213,7 +266,7 @@ void Metadata::LoadQueryBoundaries() {
std::string query_filename(data_filename_); std::string query_filename(data_filename_);
// default query file name // default query file name
query_filename.append(".query"); query_filename.append(".query");
TextReader<size_t> reader(query_filename.c_str()); TextReader<size_t> reader(query_filename.c_str(), false);
reader.ReadAllLines(); reader.ReadAllLines();
if (reader.Lines().size() <= 0) { if (reader.Lines().size() <= 0) {
return; return;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <functional>
namespace LightGBM { namespace LightGBM {
...@@ -20,44 +21,65 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt) ...@@ -20,44 +21,65 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt)
} }
} }
bool CheckHasLabelForLibsvm(std::string& str) { int GetLabelIdxForLibsvm(std::string& str, int num_features, int label_idx) {
if (num_features <= 0) {
return label_idx;
}
str = Common::Trim(str); str = Common::Trim(str);
auto pos_space = str.find_first_of(" \f\n\r\t\v"); auto pos_space = str.find_first_of(" \f\n\r\t\v");
auto pos_colon = str.find_first_of(":"); auto pos_colon = str.find_first_of(":");
if (pos_colon == std::string::npos || pos_colon > pos_space) { if (pos_space == std::string::npos || pos_space < pos_colon) {
return true; return label_idx;
} else { } else {
return false; return -1;
} }
} }
bool CheckHasLabelForTSV(std::string& str, int num_features) { int GetLabelIdxForTSV(std::string& str, int num_features, int label_idx) {
if (num_features <= 0) {
return label_idx;
}
str = Common::Trim(str); str = Common::Trim(str);
auto tokens = Common::Split(str.c_str(), '\t'); auto tokens = Common::Split(str.c_str(), '\t');
if (static_cast<int>(tokens.size()) == num_features) { if (static_cast<int>(tokens.size()) == num_features) {
return false; return -1;
} else { } else {
return true; return label_idx;
} }
} }
bool CheckHasLabelForCSV(std::string& str, int num_features) { int GetLabelIdxForCSV(std::string& str, int num_features, int label_idx) {
if (num_features <= 0) {
return label_idx;
}
str = Common::Trim(str); str = Common::Trim(str);
auto tokens = Common::Split(str.c_str(), ','); auto tokens = Common::Split(str.c_str(), ',');
if (static_cast<int>(tokens.size()) == num_features) { if (static_cast<int>(tokens.size()) == num_features) {
return false; return -1;
} else { } else {
return true; return label_idx;
} }
} }
Parser* Parser::CreateParser(const char* filename, int num_features, bool* has_label) { enum DataType {
INVALID,
CSV,
TSV,
LIBSVM
};
Parser* Parser::CreateParser(const char* filename, bool has_header, int num_features, int label_idx) {
std::ifstream tmp_file; std::ifstream tmp_file;
tmp_file.open(filename); tmp_file.open(filename);
if (!tmp_file.is_open()) { if (!tmp_file.is_open()) {
Log::Fatal("Data file: %s doesn't exist", filename); Log::Fatal("Data file: %s doesn't exist", filename);
} }
std::string line1, line2; std::string line1, line2;
if (has_header) {
if (!tmp_file.eof()) {
std::getline(tmp_file, line1);
}
}
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
std::getline(tmp_file, line1); std::getline(tmp_file, line1);
} else { } else {
...@@ -75,43 +97,47 @@ Parser* Parser::CreateParser(const char* filename, int num_features, bool* has_l ...@@ -75,43 +97,47 @@ Parser* Parser::CreateParser(const char* filename, int num_features, bool* has_l
// Get some statistic from 2 line // Get some statistic from 2 line
GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt); GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2); GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
Parser* ret = nullptr;
DataType type = DataType::INVALID;
if (line2.size() == 0) { if (line2.size() == 0) {
// if only have one line on file // if only have one line on file
if (colon_cnt > 0) { if (colon_cnt > 0) {
ret = new LibSVMParser(); type = DataType::LIBSVM;
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForLibsvm(line1);
}
} else if (tab_cnt > 0) { } else if (tab_cnt > 0) {
ret = new TSVParser(); type = DataType::TSV;
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForTSV(line1, num_features);
}
} else if (comma_cnt > 0) { } else if (comma_cnt > 0) {
ret = new CSVParser(); type = DataType::CSV;
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForCSV(line1, num_features);
}
} }
} else { } else {
if (colon_cnt > 0 || colon_cnt2 > 0) { if (colon_cnt > 0 || colon_cnt2 > 0) {
ret = new LibSVMParser(); type = DataType::LIBSVM;
if (num_features > 0 && has_label != nullptr) { } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
*has_label = CheckHasLabelForLibsvm(line1); type = DataType::TSV;
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
type = DataType::CSV;
} }
} }
else if (tab_cnt == tab_cnt2 && tab_cnt > 0) { if (type == DataType::INVALID) {
ret = new TSVParser(); Log::Fatal("Unkown format of training data");
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForTSV(line1, num_features);
} }
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) { Parser* ret = nullptr;
ret = new CSVParser(); if (type == DataType::LIBSVM) {
if (num_features > 0 && has_label != nullptr) { label_idx = GetLabelIdxForLibsvm(line1, num_features, label_idx);
*has_label = CheckHasLabelForCSV(line1, num_features); ret = new LibSVMParser(label_idx);
} }
else if (type == DataType::TSV) {
label_idx = GetLabelIdxForTSV(line1, num_features, label_idx);
ret = new TSVParser(label_idx);
} }
else if (type == DataType::CSV) {
label_idx = GetLabelIdxForCSV(line1, num_features, label_idx);
ret = new CSVParser(label_idx);
}
if (label_idx < 0) {
Log::Info("Data file: %s doesn't contain label column", filename);
} }
return ret; return ret;
} }
......
...@@ -14,14 +14,23 @@ namespace LightGBM { ...@@ -14,14 +14,23 @@ namespace LightGBM {
class CSVParser: public Parser { class CSVParser: public Parser {
public: public:
explicit CSVParser(int label_idx)
:label_idx_(label_idx) {
}
inline void ParseOneLine(const char* str, inline void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features) const override { std::vector<std::pair<int, float>>* out_features, float* out_label) const override {
int idx = 0; int idx = 0;
double val = 0.0; float val = 0.0f;
int bias = 0;
*out_label = 0.0f;
while (*str != '\0') { while (*str != '\0') {
str = Common::Atof(str, &val); str = Common::Atof(str, &val);
if (fabs(val) > 1e-10) { if (idx == label_idx_) {
out_features->emplace_back(idx, val); *out_label = val;
bias = -1;
}
else if (fabs(val) > 1e-10) {
out_features->emplace_back(idx + bias, val);
} }
++idx; ++idx;
if (*str == ',') { if (*str == ',') {
...@@ -31,28 +40,27 @@ public: ...@@ -31,28 +40,27 @@ public:
} }
} }
} }
inline void ParseOneLine(const char* str, std::vector<std::pair<int, double>>* out_features, private:
double* out_label) const override { int label_idx_ = 0;
// first column is label
str = Common::Atof(str, out_label);
if (*str == ',') {
++str;
} else if (*str != '\0') {
Log::Fatal("input format error, should be CSV");
}
return ParseOneLine(str, out_features);
}
}; };
class TSVParser: public Parser { class TSVParser: public Parser {
public: public:
inline void ParseOneLine(const char* str, std::vector<std::pair<int, double>>* out_features) const override { explicit TSVParser(int label_idx)
:label_idx_(label_idx) {
}
inline void ParseOneLine(const char* str,
std::vector<std::pair<int, float>>* out_features, float* out_label) const override {
int idx = 0; int idx = 0;
double val = 0.0; float val = 0.0f;
int bias = 0;
while (*str != '\0') { while (*str != '\0') {
str = Common::Atof(str, &val); str = Common::Atof(str, &val);
if (fabs(val) > 1e-10) { if (idx == label_idx_) {
out_features->emplace_back(idx, val); *out_label = val;
bias = -1;
} else if (fabs(val) > 1e-10) {
out_features->emplace_back(idx + bias, val);
} }
++idx; ++idx;
if (*str == '\t') { if (*str == '\t') {
...@@ -62,24 +70,27 @@ public: ...@@ -62,24 +70,27 @@ public:
} }
} }
} }
inline void ParseOneLine(const char* str, std::vector<std::pair<int, double>>* out_features, private:
double* out_label) const override { int label_idx_ = 0;
// first column is label
str = Common::Atof(str, out_label);
if (*str == '\t') {
++str;
} else if (*str != '\0') {
Log::Fatal("input format error, should be TSV");
}
return ParseOneLine(str, out_features);
}
}; };
class LibSVMParser: public Parser { class LibSVMParser: public Parser {
public: public:
inline void ParseOneLine(const char* str, std::vector<std::pair<int, double>>* out_features) const override { explicit LibSVMParser(int label_idx)
:label_idx_(label_idx) {
if (label_idx > 0) {
Log::Fatal("label should be the first column in Libsvm file");
}
}
inline void ParseOneLine(const char* str,
std::vector<std::pair<int, float>>* out_features, float* out_label) const override {
int idx = 0; int idx = 0;
double val = 0.0; float val = 0.0f;
if (label_idx_ == 0) {
str = Common::Atof(str, &val);
*out_label = val;
str = Common::SkipSpaceAndTab(str);
}
while (*str != '\0') { while (*str != '\0') {
str = Common::Atoi(str, &idx); str = Common::Atoi(str, &idx);
str = Common::SkipSpaceAndTab(str); str = Common::SkipSpaceAndTab(str);
...@@ -93,13 +104,9 @@ public: ...@@ -93,13 +104,9 @@ public:
str = Common::SkipSpaceAndTab(str); str = Common::SkipSpaceAndTab(str);
} }
} }
inline void ParseOneLine(const char* str, std::vector<std::pair<int, double>>* out_features, private:
double* out_label) const override { int label_idx_ = 0;
// first column is label
str = Common::Atof(str, out_label);
str = Common::SkipSpaceAndTab(str);
return ParseOneLine(str, out_features);
}
}; };
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_IO_PARSER_HPP_ #endif // LightGBM_IO_PARSER_HPP_
...@@ -64,7 +64,7 @@ public: ...@@ -64,7 +64,7 @@ public:
data_size_t cur_pos = fast_pair.second; data_size_t cur_pos = fast_pair.second;
data_size_t lte_count = 0; data_size_t lte_count = 0;
data_size_t gt_count = 0; data_size_t gt_count = 0;
for (data_size_t i = 0; i < num_data; i++) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
while (cur_pos < idx && j < num_vals_) { while (cur_pos < idx && j < num_vals_) {
++j; ++j;
...@@ -92,12 +92,12 @@ public: ...@@ -92,12 +92,12 @@ public:
void FinishLoad() override { void FinishLoad() override {
// get total non zero size // get total non zero size
size_t non_zero_size = 0; size_t non_zero_size = 0;
for (size_t i = 0; i < push_buffers_.size(); i++) { for (size_t i = 0; i < push_buffers_.size(); ++i) {
non_zero_size += push_buffers_[i].size(); non_zero_size += push_buffers_[i].size();
} }
// merge // merge
non_zero_pair_.reserve(non_zero_size); non_zero_pair_.reserve(non_zero_size);
for (size_t i = 0; i < push_buffers_.size(); i++) { for (size_t i = 0; i < push_buffers_.size(); ++i) {
non_zero_pair_.insert(non_zero_pair_.end(), push_buffers_[i].begin(), push_buffers_[i].end()); non_zero_pair_.insert(non_zero_pair_.end(), push_buffers_[i].begin(), push_buffers_[i].end());
push_buffers_[i].clear(); push_buffers_[i].clear();
push_buffers_[i].shrink_to_fit(); push_buffers_[i].shrink_to_fit();
...@@ -122,7 +122,7 @@ public: ...@@ -122,7 +122,7 @@ public:
// transform to delta array // transform to delta array
const uint8_t kMaxDelta = 255; const uint8_t kMaxDelta = 255;
data_size_t last_idx = 0; data_size_t last_idx = 0;
for (size_t i = 0; i < non_zero_pair.size(); i++) { for (size_t i = 0; i < non_zero_pair.size(); ++i) {
const data_size_t cur_idx = non_zero_pair[i].first; const data_size_t cur_idx = non_zero_pair[i].first;
const VAL_T bin = non_zero_pair[i].second; const VAL_T bin = non_zero_pair[i].second;
data_size_t cur_delta = cur_idx - last_idx; data_size_t cur_delta = cur_idx - last_idx;
...@@ -198,7 +198,7 @@ public: ...@@ -198,7 +198,7 @@ public:
delta_.clear(); delta_.clear();
vals_.clear(); vals_.clear();
num_vals_ = tmp_num_vals; num_vals_ = tmp_num_vals;
for (data_size_t i = 0; i < num_vals_; i++) { for (data_size_t i = 0; i < num_vals_; ++i) {
delta_.push_back(tmp_delta[i]); delta_.push_back(tmp_delta[i]);
vals_.push_back(tmp_vals[i]); vals_.push_back(tmp_vals[i]);
} }
......
...@@ -23,11 +23,14 @@ Tree::Tree(int max_leaves) ...@@ -23,11 +23,14 @@ Tree::Tree(int max_leaves)
split_feature_ = new int[max_leaves_ - 1]; split_feature_ = new int[max_leaves_ - 1];
split_feature_real_ = new int[max_leaves_ - 1]; split_feature_real_ = new int[max_leaves_ - 1];
threshold_in_bin_ = new unsigned int[max_leaves_ - 1]; threshold_in_bin_ = new unsigned int[max_leaves_ - 1];
threshold_ = new double[max_leaves_ - 1]; threshold_ = new float[max_leaves_ - 1];
split_gain_ = new double[max_leaves_ - 1]; split_gain_ = new float[max_leaves_ - 1];
leaf_parent_ = new int[max_leaves_]; leaf_parent_ = new int[max_leaves_];
leaf_value_ = new score_t[max_leaves_]; leaf_value_ = new float[max_leaves_];
leaf_depth_ = new int[max_leaves_];
// root is in the depth 1
leaf_depth_[0] = 1;
num_leaves_ = 1; num_leaves_ = 1;
leaf_parent_[0] = -1; leaf_parent_[0] = -1;
} }
...@@ -41,10 +44,11 @@ Tree::~Tree() { ...@@ -41,10 +44,11 @@ Tree::~Tree() {
if (threshold_ != nullptr) { delete[] threshold_; } if (threshold_ != nullptr) { delete[] threshold_; }
if (split_gain_ != nullptr) { delete[] split_gain_; } if (split_gain_ != nullptr) { delete[] split_gain_; }
if (leaf_value_ != nullptr) { delete[] leaf_value_; } if (leaf_value_ != nullptr) { delete[] leaf_value_; }
if (leaf_depth_ != nullptr) { delete[] leaf_depth_; }
} }
int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feature, int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feature,
double threshold, score_t left_value, score_t right_value, double gain) { float threshold, float left_value, float right_value, float gain) {
int new_node_idx = num_leaves_ - 1; int new_node_idx = num_leaves_ - 1;
// update parent info // update parent info
int parent = leaf_parent_[leaf]; int parent = leaf_parent_[leaf];
...@@ -70,19 +74,21 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat ...@@ -70,19 +74,21 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
leaf_parent_[num_leaves_] = new_node_idx; leaf_parent_[num_leaves_] = new_node_idx;
leaf_value_[leaf] = left_value; leaf_value_[leaf] = left_value;
leaf_value_[num_leaves_] = right_value; leaf_value_[num_leaves_] = right_value;
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
leaf_depth_[leaf]++;
++num_leaves_; ++num_leaves_;
return num_leaves_ - 1; return num_leaves_ - 1;
} }
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, score_t* score) const { void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, score_t* score) const {
Threading::For<data_size_t>(0, num_data, [this, data, score](int, data_size_t start, data_size_t end) { Threading::For<data_size_t>(0, num_data, [this, data, score](int, data_size_t start, data_size_t end) {
std::vector<BinIterator*> iterators; std::vector<BinIterator*> iterators;
for (int i = 0; i < data->num_features(); i++) { for (int i = 0; i < data->num_features(); ++i) {
iterators.push_back(data->FeatureAt(i)->bin_data()->GetIterator(start)); iterators.push_back(data->FeatureAt(i)->bin_data()->GetIterator(start));
} }
for (data_size_t i = start; i < end; i++) { for (data_size_t i = start; i < end; ++i) {
score[i] += leaf_value_[GetLeaf(iterators, i)]; score[i] += leaf_value_[GetLeaf(iterators, i)];
} }
}); });
...@@ -93,10 +99,10 @@ void Tree::AddPredictionToScore(const Dataset* data, const data_size_t* used_dat ...@@ -93,10 +99,10 @@ void Tree::AddPredictionToScore(const Dataset* data, const data_size_t* used_dat
Threading::For<data_size_t>(0, num_data, Threading::For<data_size_t>(0, num_data,
[this, data, used_data_indices, score](int, data_size_t start, data_size_t end) { [this, data, used_data_indices, score](int, data_size_t start, data_size_t end) {
std::vector<BinIterator*> iterators; std::vector<BinIterator*> iterators;
for (int i = 0; i < data->num_features(); i++) { for (int i = 0; i < data->num_features(); ++i) {
iterators.push_back(data->FeatureAt(i)->bin_data()->GetIterator(used_data_indices[start])); iterators.push_back(data->FeatureAt(i)->bin_data()->GetIterator(used_data_indices[start]));
} }
for (data_size_t i = start; i < end; i++) { for (data_size_t i = start; i < end; ++i) {
score[used_data_indices[i]] += leaf_value_[GetLeaf(iterators, used_data_indices[i])]; score[used_data_indices[i]] += leaf_value_[GetLeaf(iterators, used_data_indices[i])];
} }
}); });
...@@ -108,9 +114,9 @@ std::string Tree::ToString() { ...@@ -108,9 +114,9 @@ std::string Tree::ToString() {
ss << "split_feature=" ss << "split_feature="
<< Common::ArrayToString<int>(split_feature_real_, num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<int>(split_feature_real_, num_leaves_ - 1, ' ') << std::endl;
ss << "split_gain=" ss << "split_gain="
<< Common::ArrayToString<double>(split_gain_, num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<float>(split_gain_, num_leaves_ - 1, ' ') << std::endl;
ss << "threshold=" ss << "threshold="
<< Common::ArrayToString<double>(threshold_, num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<float>(threshold_, num_leaves_ - 1, ' ') << std::endl;
ss << "left_child=" ss << "left_child="
<< Common::ArrayToString<int>(left_child_, num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<int>(left_child_, num_leaves_ - 1, ' ') << std::endl;
ss << "right_child=" ss << "right_child="
...@@ -118,7 +124,7 @@ std::string Tree::ToString() { ...@@ -118,7 +124,7 @@ std::string Tree::ToString() {
ss << "leaf_parent=" ss << "leaf_parent="
<< Common::ArrayToString<int>(leaf_parent_, num_leaves_, ' ') << std::endl; << Common::ArrayToString<int>(leaf_parent_, num_leaves_, ' ') << std::endl;
ss << "leaf_value=" ss << "leaf_value="
<< Common::ArrayToString<score_t>(leaf_value_, num_leaves_, ' ') << std::endl; << Common::ArrayToString<float>(leaf_value_, num_leaves_, ' ') << std::endl;
ss << std::endl; ss << std::endl;
return ss.str(); return ss.str();
} }
...@@ -148,19 +154,20 @@ Tree::Tree(const std::string& str) { ...@@ -148,19 +154,20 @@ Tree::Tree(const std::string& str) {
left_child_ = new int[num_leaves_ - 1]; left_child_ = new int[num_leaves_ - 1];
right_child_ = new int[num_leaves_ - 1]; right_child_ = new int[num_leaves_ - 1];
split_feature_real_ = new int[num_leaves_ - 1]; split_feature_real_ = new int[num_leaves_ - 1];
threshold_ = new double[num_leaves_ - 1]; threshold_ = new float[num_leaves_ - 1];
split_gain_ = new double[num_leaves_ - 1]; split_gain_ = new float[num_leaves_ - 1];
leaf_parent_ = new int[num_leaves_]; leaf_parent_ = new int[num_leaves_];
leaf_value_ = new score_t[num_leaves_]; leaf_value_ = new float[num_leaves_];
split_feature_ = nullptr; split_feature_ = nullptr;
threshold_in_bin_ = nullptr; threshold_in_bin_ = nullptr;
leaf_depth_ = nullptr;
Common::StringToIntArray(key_vals["split_feature"], ' ', Common::StringToIntArray(key_vals["split_feature"], ' ',
num_leaves_ - 1, split_feature_real_); num_leaves_ - 1, split_feature_real_);
Common::StringToDoubleArray(key_vals["split_gain"], ' ', Common::StringToFloatArray(key_vals["split_gain"], ' ',
num_leaves_ - 1, split_gain_); num_leaves_ - 1, split_gain_);
Common::StringToDoubleArray(key_vals["threshold"], ' ', Common::StringToFloatArray(key_vals["threshold"], ' ',
num_leaves_ - 1, threshold_); num_leaves_ - 1, threshold_);
Common::StringToIntArray(key_vals["left_child"], ' ', Common::StringToIntArray(key_vals["left_child"], ' ',
num_leaves_ - 1, left_child_); num_leaves_ - 1, left_child_);
...@@ -168,7 +175,7 @@ Tree::Tree(const std::string& str) { ...@@ -168,7 +175,7 @@ Tree::Tree(const std::string& str) {
num_leaves_ - 1, right_child_); num_leaves_ - 1, right_child_);
Common::StringToIntArray(key_vals["leaf_parent"], ' ', Common::StringToIntArray(key_vals["leaf_parent"], ' ',
num_leaves_ , leaf_parent_); num_leaves_ , leaf_parent_);
Common::StringToDoubleArray(key_vals["leaf_value"], ' ', Common::StringToFloatArray(key_vals["leaf_value"], ' ',
num_leaves_ , leaf_value_); num_leaves_ , leaf_value_);
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include <sstream>
namespace LightGBM { namespace LightGBM {
...@@ -18,9 +19,6 @@ template<typename PointWiseLossCalculator> ...@@ -18,9 +19,6 @@ template<typename PointWiseLossCalculator>
class BinaryMetric: public Metric { class BinaryMetric: public Metric {
public: public:
explicit BinaryMetric(const MetricConfig& config) { explicit BinaryMetric(const MetricConfig& config) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq;
the_bigger_the_better = false;
sigmoid_ = static_cast<score_t>(config.sigmoid); sigmoid_ = static_cast<score_t>(config.sigmoid);
if (sigmoid_ <= 0.0f) { if (sigmoid_ <= 0.0f) {
Log::Fatal("Sigmoid param %f should greater than zero", sigmoid_); Log::Fatal("Sigmoid param %f should greater than zero", sigmoid_);
...@@ -32,7 +30,9 @@ public: ...@@ -32,7 +30,9 @@ public:
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
name = test_name; std::stringstream str_buf;
str_buf << test_name << "'s " << PointWiseLossCalculator::Name();
name_ = str_buf.str();
num_data_ = num_data; num_data_ = num_data;
// get label // get label
label_ = metadata.label(); label_ = metadata.label();
...@@ -41,7 +41,7 @@ public: ...@@ -41,7 +41,7 @@ public:
weights_ = metadata.weights(); weights_ = metadata.weights();
if (weights_ == nullptr) { if (weights_ == nullptr) {
sum_weights_ = static_cast<double>(num_data_); sum_weights_ = static_cast<float>(num_data_);
} else { } else {
sum_weights_ = 0.0f; sum_weights_ = 0.0f;
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
...@@ -50,11 +50,18 @@ public: ...@@ -50,11 +50,18 @@ public:
} }
} }
score_t PrintAndGetLoss(int iter, const score_t* score) const override { const char* GetName() const override {
return name_.c_str();
}
bool is_bigger_better() const override {
return false;
}
std::vector<float> Eval(const score_t* score) const override {
score_t sum_loss = 0.0f; score_t sum_loss = 0.0f;
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
if (weights_ == nullptr) { if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// sigmoid transform // sigmoid transform
score_t prob = 1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * score[i])); score_t prob = 1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * score[i]));
...@@ -62,7 +69,7 @@ public: ...@@ -62,7 +69,7 @@ public:
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob); sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob);
} }
} else { } else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// sigmoid transform // sigmoid transform
score_t prob = 1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * score[i])); score_t prob = 1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * score[i]));
...@@ -71,17 +78,10 @@ public: ...@@ -71,17 +78,10 @@ public:
} }
} }
score_t loss = sum_loss / sum_weights_; score_t loss = sum_loss / sum_weights_;
if (output_freq_ > 0 && iter % output_freq_ == 0){ return std::vector<float>(1, static_cast<float>(loss));
Log::Info("Iteration:%d, %s's %s: %f", iter, name, PointWiseLossCalculator::Name(), loss);
}
return loss;
}
return 0.0f;
} }
private: private:
/*! \brief Output frequently */
int output_freq_;
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
...@@ -89,9 +89,9 @@ private: ...@@ -89,9 +89,9 @@ private:
/*! \brief Pointer of weighs */ /*! \brief Pointer of weighs */
const float* weights_; const float* weights_;
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; float sum_weights_;
/*! \brief Name of test set */ /*! \brief Name of test set */
const char* name; std::string name_;
/*! \brief Sigmoid parameter */ /*! \brief Sigmoid parameter */
score_t sigmoid_; score_t sigmoid_;
}; };
...@@ -145,17 +145,26 @@ public: ...@@ -145,17 +145,26 @@ public:
*/ */
class AUCMetric: public Metric { class AUCMetric: public Metric {
public: public:
explicit AUCMetric(const MetricConfig& config) { explicit AUCMetric(const MetricConfig&) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq;
the_bigger_the_better = true;
} }
virtual ~AUCMetric() { virtual ~AUCMetric() {
} }
const char* GetName() const override {
return name_.c_str();
}
bool is_bigger_better() const override {
return true;
}
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
name = test_name; std::stringstream str_buf;
str_buf << test_name << "'s AUC";
name_ = str_buf.str();
num_data_ = num_data; num_data_ = num_data;
// get label // get label
label_ = metadata.label(); label_ = metadata.label();
...@@ -163,7 +172,7 @@ public: ...@@ -163,7 +172,7 @@ public:
weights_ = metadata.weights(); weights_ = metadata.weights();
if (weights_ == nullptr) { if (weights_ == nullptr) {
sum_weights_ = static_cast<double>(num_data_); sum_weights_ = static_cast<float>(num_data_);
} else { } else {
sum_weights_ = 0.0f; sum_weights_ = 0.0f;
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
...@@ -172,8 +181,7 @@ public: ...@@ -172,8 +181,7 @@ public:
} }
} }
score_t PrintAndGetLoss(int iter, const score_t* score) const override { std::vector<float> Eval(const score_t* score) const override {
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
// get indices sorted by score, descent order // get indices sorted by score, descent order
std::vector<data_size_t> sorted_idx; std::vector<data_size_t> sorted_idx;
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
...@@ -181,15 +189,15 @@ public: ...@@ -181,15 +189,15 @@ public:
} }
std::sort(sorted_idx.begin(), sorted_idx.end(), [score](data_size_t a, data_size_t b) {return score[a] > score[b]; }); std::sort(sorted_idx.begin(), sorted_idx.end(), [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
// temp sum of postive label // temp sum of postive label
double cur_pos = 0.0; score_t cur_pos = 0.0f;
// total sum of postive label // total sum of postive label
double sum_pos = 0.0; score_t sum_pos = 0.0f;
// accumlate of auc // accumlate of auc
double accum = 0.0; score_t accum = 0.0f;
// temp sum of negative label // temp sum of negative label
double cur_neg = 0.0; score_t cur_neg = 0.0f;
score_t threshold = score[sorted_idx[0]]; score_t threshold = score[sorted_idx[0]];
if (weights_ == nullptr) { // not weights if (weights_ == nullptr) { // no weights
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
const float cur_label = label_[sorted_idx[i]]; const float cur_label = label_[sorted_idx[i]];
const score_t cur_score = score[sorted_idx[i]]; const score_t cur_score = score[sorted_idx[i]];
...@@ -197,12 +205,12 @@ public: ...@@ -197,12 +205,12 @@ public:
if (cur_score != threshold) { if (cur_score != threshold) {
threshold = cur_score; threshold = cur_score;
// accmulate // accmulate
accum += cur_neg*(cur_pos * 0.5 + sum_pos); accum += cur_neg*(cur_pos * 0.5f + sum_pos);
sum_pos += cur_pos; sum_pos += cur_pos;
// reset // reset
cur_neg = cur_pos = 0.0; cur_neg = cur_pos = 0.0f;
} }
cur_neg += 1.0 - cur_label; cur_neg += 1.0f - cur_label;
cur_pos += cur_label; cur_pos += cur_label;
} }
} else { // has weights } else { // has weights
...@@ -214,32 +222,25 @@ public: ...@@ -214,32 +222,25 @@ public:
if (cur_score != threshold) { if (cur_score != threshold) {
threshold = cur_score; threshold = cur_score;
// accmulate // accmulate
accum += cur_neg*(cur_pos * 0.5 + sum_pos); accum += cur_neg*(cur_pos * 0.5f + sum_pos);
sum_pos += cur_pos; sum_pos += cur_pos;
// reset // reset
cur_neg = cur_pos = 0.0; cur_neg = cur_pos = 0.0f;
} }
cur_neg += (1.0 - cur_label)*cur_weight; cur_neg += (1.0f - cur_label)*cur_weight;
cur_pos += cur_label*cur_weight; cur_pos += cur_label*cur_weight;
} }
} }
accum += cur_neg*(cur_pos * 0.5 + sum_pos); accum += cur_neg*(cur_pos * 0.5f + sum_pos);
sum_pos += cur_pos; sum_pos += cur_pos;
double auc = 1.0; score_t auc = 1.0f;
if (sum_pos > 0.0f && sum_pos != sum_weights_) { if (sum_pos > 0.0f && sum_pos != sum_weights_) {
auc = accum / (sum_pos *(sum_weights_ - sum_pos)); auc = accum / (sum_pos *(sum_weights_ - sum_pos));
} }
if (output_freq_ > 0 && iter % output_freq_ == 0){ return std::vector<float>(1, static_cast<float>(auc));
Log::Info("Iteration:%d, %s's %s: %f", iter, name, "auc", auc);
}
return auc;
}
return 0.0f;
} }
private: private:
/*! \brief Output frequently */
int output_freq_;
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
...@@ -247,9 +248,9 @@ private: ...@@ -247,9 +248,9 @@ private:
/*! \brief Pointer of weighs */ /*! \brief Pointer of weighs */
const float* weights_; const float* weights_;
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; float sum_weights_;
/*! \brief Name of test set */ /*! \brief Name of test set */
const char* name; std::string name_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -11,23 +11,23 @@ namespace LightGBM { ...@@ -11,23 +11,23 @@ namespace LightGBM {
/*! \brief Declaration for some static members */ /*! \brief Declaration for some static members */
bool DCGCalculator::is_inited_ = false; bool DCGCalculator::is_inited_ = false;
std::vector<double> DCGCalculator::label_gain_; std::vector<float> DCGCalculator::label_gain_;
std::vector<double> DCGCalculator::discount_; std::vector<float> DCGCalculator::discount_;
const data_size_t DCGCalculator::kMaxPosition = 10000; const data_size_t DCGCalculator::kMaxPosition = 10000;
void DCGCalculator::Init(std::vector<double> input_label_gain) { void DCGCalculator::Init(std::vector<float> input_label_gain) {
// only inited one time // only inited one time
if (is_inited_) { return; } if (is_inited_) { return; }
label_gain_ = input_label_gain; label_gain_ = input_label_gain;
discount_.clear(); discount_.clear();
for (data_size_t i = 0; i < kMaxPosition; ++i) { for (data_size_t i = 0; i < kMaxPosition; ++i) {
discount_.emplace_back(1.0 / std::log2(2.0 + i)); discount_.emplace_back(1.0f / std::log2(2.0f + i));
} }
is_inited_ = true; is_inited_ = true;
} }
double DCGCalculator::CalMaxDCGAtK(data_size_t k, const float* label, data_size_t num_data) { float DCGCalculator::CalMaxDCGAtK(data_size_t k, const float* label, data_size_t num_data) {
double ret = 0.0; float ret = 0.0f;
// counts for all labels // counts for all labels
std::vector<data_size_t> label_cnt(label_gain_.size(), 0); std::vector<data_size_t> label_cnt(label_gain_.size(), 0);
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
...@@ -53,14 +53,14 @@ double DCGCalculator::CalMaxDCGAtK(data_size_t k, const float* label, data_size_ ...@@ -53,14 +53,14 @@ double DCGCalculator::CalMaxDCGAtK(data_size_t k, const float* label, data_size_
void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks, void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks,
const float* label, const float* label,
data_size_t num_data, data_size_t num_data,
std::vector<double>* out) { std::vector<float>* out) {
std::vector<data_size_t> label_cnt(label_gain_.size(), 0); std::vector<data_size_t> label_cnt(label_gain_.size(), 0);
// counts for all labels // counts for all labels
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
if (static_cast<size_t>(label[i]) >= label_cnt.size()) { Log::Fatal("label excel %d", label[i]); } if (static_cast<size_t>(label[i]) >= label_cnt.size()) { Log::Fatal("label excel %d", label[i]); }
++label_cnt[static_cast<int>(label[i])]; ++label_cnt[static_cast<int>(label[i])];
} }
double cur_result = 0.0; float cur_result = 0.0f;
data_size_t cur_left = 0; data_size_t cur_left = 0;
size_t top_label = label_gain_.size() - 1; size_t top_label = label_gain_.size() - 1;
// calculate k Max DCG by one pass // calculate k Max DCG by one pass
...@@ -83,7 +83,7 @@ void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks, ...@@ -83,7 +83,7 @@ void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks,
} }
double DCGCalculator::CalDCGAtK(data_size_t k, const float* label, float DCGCalculator::CalDCGAtK(data_size_t k, const float* label,
const score_t* score, data_size_t num_data) { const score_t* score, data_size_t num_data) {
// get sorted indices by score // get sorted indices by score
std::vector<data_size_t> sorted_idx; std::vector<data_size_t> sorted_idx;
...@@ -94,7 +94,7 @@ double DCGCalculator::CalDCGAtK(data_size_t k, const float* label, ...@@ -94,7 +94,7 @@ double DCGCalculator::CalDCGAtK(data_size_t k, const float* label,
[score](data_size_t a, data_size_t b) {return score[a] > score[b]; }); [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
if (k > num_data) { k = num_data; } if (k > num_data) { k = num_data; }
double dcg = 0.0; float dcg = 0.0f;
// calculate dcg // calculate dcg
for (data_size_t i = 0; i < k; ++i) { for (data_size_t i = 0; i < k; ++i) {
data_size_t idx = sorted_idx[i]; data_size_t idx = sorted_idx[i];
...@@ -104,7 +104,7 @@ double DCGCalculator::CalDCGAtK(data_size_t k, const float* label, ...@@ -104,7 +104,7 @@ double DCGCalculator::CalDCGAtK(data_size_t k, const float* label,
} }
void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const float* label, void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const float* label,
const score_t * score, data_size_t num_data, std::vector<double>* out) { const score_t * score, data_size_t num_data, std::vector<float>* out) {
// get sorted indices by score // get sorted indices by score
std::vector<data_size_t> sorted_idx; std::vector<data_size_t> sorted_idx;
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
...@@ -113,7 +113,7 @@ void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const float* labe ...@@ -113,7 +113,7 @@ void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const float* labe
std::sort(sorted_idx.begin(), sorted_idx.end(), std::sort(sorted_idx.begin(), sorted_idx.end(),
[score](data_size_t a, data_size_t b) {return score[a] > score[b]; }); [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
double cur_result = 0.0; float cur_result = 0.0f;
data_size_t cur_left = 0; data_size_t cur_left = 0;
// calculate multi dcg by one pass // calculate multi dcg by one pass
for (size_t i = 0; i < ks.size(); ++i) { for (size_t i = 0; i < ks.size(); ++i) {
......
...@@ -2,22 +2,27 @@ ...@@ -2,22 +2,27 @@
#include "regression_metric.hpp" #include "regression_metric.hpp"
#include "binary_metric.hpp" #include "binary_metric.hpp"
#include "rank_metric.hpp" #include "rank_metric.hpp"
#include "multiclass_metric.hpp"
namespace LightGBM { namespace LightGBM {
Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) { Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) {
if (type == "l2") { if (type == std::string("l2")) {
return new L2Metric(config); return new L2Metric(config);
} else if (type == "l1") { } else if (type == std::string("l1")) {
return new L1Metric(config); return new L1Metric(config);
} else if (type == "binary_logloss") { } else if (type == std::string("binary_logloss")) {
return new BinaryLoglossMetric(config); return new BinaryLoglossMetric(config);
} else if (type == "binary_error") { } else if (type == std::string("binary_error")) {
return new BinaryErrorMetric(config); return new BinaryErrorMetric(config);
} else if (type == "auc") { } else if (type == std::string("auc")) {
return new AUCMetric(config); return new AUCMetric(config);
} else if (type == "ndcg") { } else if (type == std::string("ndcg")) {
return new NDCGMetric(config); return new NDCGMetric(config);
} else if (type == std::string("multi_logloss")) {
return new MultiLoglossMetric(config);
} else if (type == std::string("multi_error")) {
return new MultiErrorMetric(config);
} }
return nullptr; return nullptr;
} }
......
#ifndef LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
#define LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
#include <LightGBM/utils/log.h>
#include <LightGBM/metric.h>

#include <cmath>
#include <sstream>
#include <string>
#include <vector>
namespace LightGBM {
/*!
* \brief Metric for multiclass task.
* Use static class "PointWiseLossCalculator" to calculate loss point-wise
*/
template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric {
public:
  explicit MulticlassMetric(const MetricConfig& config) {
    num_class_ = config.num_class;
  }
  virtual ~MulticlassMetric() {
  }
  /*!
  * \brief Bind this metric to one test set
  * \param test_name Name of the test set (used to build the display name)
  * \param metadata Source of labels and (optional) per-row weights
  * \param num_data Number of rows in the test set
  */
  void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
    std::stringstream str_buf;
    str_buf << test_name << "'s " << PointWiseLossCalculator::Name();
    name_ = str_buf.str();
    num_data_ = num_data;
    // get label
    label_ = metadata.label();
    // get weights; nullptr means every row has weight 1
    weights_ = metadata.weights();
    if (weights_ == nullptr) {
      sum_weights_ = static_cast<float>(num_data_);
    } else {
      sum_weights_ = 0.0f;
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }
  const char* GetName() const override {
    return name_.c_str();
  }
  bool is_bigger_better() const override {
    // point-wise multiclass metrics here are losses: smaller is better
    return false;
  }
  /*!
  * \brief Compute the weighted average point-wise loss over the test set
  * \param score Raw scores, laid out class-major: score[k * num_data_ + i]
  * \return One-element vector holding the averaged loss
  */
  std::vector<float> Eval(const score_t* score) const override {
    score_t sum_loss = 0.0f;
    if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        // gather this row's per-class scores
        std::vector<float> rec(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          rec[k] = static_cast<float>(score[k * num_data_ + i]);
        }
        // add loss
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
      }
    } else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        // gather this row's per-class scores
        std::vector<float> rec(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          rec[k] = static_cast<float>(score[k * num_data_ + i]);
        }
        // add weighted loss
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
      }
    }
    score_t loss = sum_loss / sum_weights_;
    return std::vector<float>(1, static_cast<float>(loss));
  }
private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Pointer of weighs */
  const float* weights_;
  /*! \brief Sum weights */
  float sum_weights_;
  /*! \brief Name of this test set */
  std::string name_;
};
/*! \brief Classification error rate (0/1 loss) for multiclass task */
class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
public:
  explicit MultiErrorMetric(const MetricConfig& config) :MulticlassMetric<MultiErrorMetric>(config) {}
  /*!
  * \brief 0/1 loss for one row
  * \param label True class index for this row
  * \param score Per-class raw scores for this row
  * \return 1.0f when some other class strictly out-scores the true class
  *         (misclassified), 0.0f otherwise.
  *         Previous version had the return values inverted, which computed
  *         accuracy while the metric is minimized ("multi error",
  *         is_bigger_better() == false).
  */
  inline static score_t LossOnPoint(float label, const std::vector<float>& score) {
    size_t k = static_cast<size_t>(label);
    for (size_t i = 0; i < score.size(); ++i) {
      if (i != k && score[i] > score[k]) {
        // another class scores strictly higher -> prediction is wrong
        return 1.0f;
      }
    }
    // true class is (at least tied for) the maximum; count as correct.
    // NOTE(review): ties count as correct here - confirm intended tie policy.
    return 0.0f;
  }
  inline static const char* Name() {
    return "multi error";
  }
};
/*! \brief Logloss for multiclass task */
class MultiLoglossMetric: public MulticlassMetric<MultiLoglossMetric> {
public:
  explicit MultiLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiLoglossMetric>(config) {}
  /*!
  * \brief Negative log-likelihood of the true class for one row
  * \param label True class index for this row
  * \param score Per-class raw scores (taken by value: transformed in place)
  * \return -log(probability of the true class), clamped by kEpsilon
  */
  inline static score_t LossOnPoint(float label, std::vector<float> score) {
    const size_t true_class = static_cast<size_t>(label);
    // turn raw scores into probabilities (in place)
    Common::Softmax(&score);
    // clamp at kEpsilon so the log never blows up on a zero probability
    const float prob = score[true_class] > kEpsilon ? score[true_class] : kEpsilon;
    return -std::log(prob);
  }
  inline static const char* Name() {
    return "multi logloss";
  }
};
} // namespace LightGBM
#endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment