Commit 7b8bb4f2 authored by Guolin Ke's avatar Guolin Ke
Browse files

some refactor for predict and metric logic

parent 33344088
...@@ -46,25 +46,23 @@ public: ...@@ -46,25 +46,23 @@ public:
/*! \brief Training logic */ /*! \brief Training logic */
virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0; virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;
/*! \brief Get eval result */ virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
virtual std::vector<std::string> EvalCurrent(bool is_eval_train) const = 0 ;
/*! \brief Get prediction result */ virtual const score_t* GetScoreAt(int data_idx, data_size_t* out_len) const = 0;
virtual const std::vector<const score_t*> PredictCurrent(bool is_predict_train) const = 0;
/*! /*!
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
virtual double PredictRaw(const double* feature_values) const = 0; virtual std::vector<double> PredictRaw(const double* feature_values) const = 0;
/*! /*!
* \brief Prediction for one record, sigmoid transformation will be used if needed * \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
virtual double Predict(const double* feature_values) const = 0; virtual std::vector<double> Predict(const double* feature_values) const = 0;
/*! /*!
* \brief Predtion for one record with leaf index * \brief Predtion for one record with leaf index
...@@ -73,14 +71,7 @@ public: ...@@ -73,14 +71,7 @@ public:
*/ */
virtual std::vector<int> PredictLeafIndex( virtual std::vector<int> PredictLeafIndex(
const double* feature_values) const = 0; const double* feature_values) const = 0;
/*!
* \brief Predtion for multiclass classification
* \param feature_values Feature value on this record
* \return Prediction result, num_class numbers per line
*/
virtual std::vector<double> PredictMulticlass(const double* value) const = 0;
/*! /*!
* \brief save model to file * \brief save model to file
*/ */
......
...@@ -206,8 +206,8 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, ...@@ -206,8 +206,8 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
* \prama out handle of created Booster * \prama out handle of created Booster
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterCreate(DatesetHandle train_data, DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
DatesetHandle valid_datas[], const DatesetHandle valid_datas[],
const char* valid_names[], const char* valid_names[],
int n_valid_datas, int n_valid_datas,
const char* parameters, const char* parameters,
...@@ -248,8 +248,8 @@ DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished); ...@@ -248,8 +248,8 @@ DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished);
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
float* grad, const float* grad,
float* hess, const float* hess,
int* is_finished); int* is_finished);
/*! /*!
...@@ -261,23 +261,20 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -261,23 +261,20 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
*/ */
DllExport int LGBM_BoosterEval(BoosterHandle handle, DllExport int LGBM_BoosterEval(BoosterHandle handle,
int data, int data,
const char** out_result); uint64_t* out_len,
double* out_results);
/*! /*!
* \brief make prediction for training data and validation datas * \brief make prediction for training data and validation datas
this can be used to support customized eval function this can be used to support customized eval function / and gradients calculation
* \param handle handle * \param handle handle
* \param data 0:training data, 1: 1st valid data, 2:2nd valid data ... * \param data 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param predict_type
* 0:raw score
* 1:with sigmoid/softmax transform(if needed)
* 2:leaf index
* \param out_result used to set a pointer to array * \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterPredict(BoosterHandle handle, DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
int data, int data,
int predict_type, uint64_t* out_len,
const float** out_result); const float** out_result);
/*! /*!
...@@ -307,7 +304,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -307,7 +304,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
uint64_t num_col, uint64_t num_col,
int predict_type, int predict_type,
uint64_t n_used_trees, uint64_t n_used_trees,
const double** out_result); double* out_result);
/*! /*!
* \brief make prediction for an new data set * \brief make prediction for an new data set
...@@ -336,7 +333,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -336,7 +333,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
uint64_t num_row, uint64_t num_row,
int predict_type, int predict_type,
uint64_t n_used_trees, uint64_t n_used_trees,
const double** out_result); double* out_result);
/*! /*!
* \brief make prediction for an new data set * \brief make prediction for an new data set
...@@ -360,17 +357,17 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -360,17 +357,17 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
int32_t ncol, int32_t ncol,
int predict_type, int predict_type,
uint64_t n_used_trees, uint64_t n_used_trees,
const double** out_result); double* out_result);
/*! /*!
* \brief save model into file * \brief save model into file
* \param handle handle * \param handle handle
* \param is_finished 1 means finised * \param num_used_model
* \param filename file name * \param filename file name
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
int is_finished, int num_used_model,
const char* filename); const char* filename);
#endif // LIGHTGBM_C_API_H_ #endif // LIGHTGBM_C_API_H_
...@@ -101,7 +101,7 @@ public: ...@@ -101,7 +101,7 @@ public:
bool is_save_binary_file = false; bool is_save_binary_file = false;
bool enable_load_from_binary_file = true; bool enable_load_from_binary_file = true;
int bin_construct_sample_cnt = 50000; int bin_construct_sample_cnt = 50000;
bool is_sigmoid = true; bool is_raw_score = true;
bool has_header = false; bool has_header = false;
/*! \brief Index or column name of label, default is the first column /*! \brief Index or column name of label, default is the first column
......
...@@ -246,7 +246,7 @@ public: ...@@ -246,7 +246,7 @@ public:
Dataset(); Dataset();
explicit Dataset(data_size_t num_data); explicit Dataset(data_size_t num_data, int num_class);
/*! \brief Destructor */ /*! \brief Destructor */
~Dataset(); ~Dataset();
......
...@@ -27,9 +27,9 @@ public: ...@@ -27,9 +27,9 @@ public:
virtual void Init(const char* test_name, virtual void Init(const char* test_name,
const Metadata& metadata, data_size_t num_data) = 0; const Metadata& metadata, data_size_t num_data) = 0;
virtual const char* GetName() const = 0; virtual std::vector<std::string> GetName() const = 0;
virtual bool is_bigger_better() const = 0; virtual score_t factor_to_bigger_better() const = 0;
/*! /*!
* \brief Calcaluting and printing metric result * \brief Calcaluting and printing metric result
* \param score Current prediction score * \param score Current prediction score
......
...@@ -81,6 +81,16 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli ...@@ -81,6 +81,16 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli
return ret; return ret;
} }
template<typename T>
inline std::string Join(const std::vector<T>& data, char delimiters) {
  // Joins the elements of `data` into one string, separated by `delimiters`.
  // Guard against empty input: the unchecked data[0] read below would be UB.
  if (data.empty()) {
    return std::string();
  }
  std::stringstream result_stream_buf;
  result_stream_buf << data[0];
  for (size_t i = 1; i < data.size(); ++i) {
    result_stream_buf << delimiters << data[i];
  }
  return result_stream_buf.str();
}
inline static const char* Atoi(const char* p, int* out) { inline static const char* Atoi(const char* p, int* out) {
int sign, value; int sign, value;
while (*p == ' ') { while (*p == ' ') {
...@@ -430,6 +440,54 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int float ...@@ -430,6 +440,54 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int float
} }
} }
inline std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int float_type, int is_row_major) {
  // Returns a closure mapping a row index to (feature_index, value) pairs.
  // BUG FIX: the lambdas previously captured the locals dptr/num_col/num_row
  // *by reference*; those locals die when this factory returns, so every call
  // dereferenced dangling references (UB). They also mutated the captured
  // pointer (dptr += ...), which would skew all subsequent calls even with a
  // valid capture. Capture by value and compute a per-call row pointer instead.
  if (float_type == 0) {
    // data is a matrix of 32-bit floats.
    const float* dptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [dptr, num_col, num_row](int row_idx) {
        CHECK(row_idx < num_row);
        std::vector<std::pair<int, double>> ret;
        // Row-major: row `row_idx` is a contiguous run of num_col values.
        const float* row_ptr = dptr + static_cast<size_t>(num_col) * row_idx;
        for (int i = 0; i < num_col; ++i) {
          ret.emplace_back(i, static_cast<double>(row_ptr[i]));
        }
        return ret;
      };
    } else {
      return [dptr, num_col, num_row](int row_idx) {
        CHECK(row_idx < num_row);
        std::vector<std::pair<int, double>> ret;
        // Column-major: element (row_idx, i) lives at column stride num_row.
        for (int i = 0; i < num_col; ++i) {
          ret.emplace_back(i, static_cast<double>(dptr[static_cast<size_t>(num_row) * i + row_idx]));
        }
        return ret;
      };
    }
  } else {
    // data is a matrix of 64-bit doubles.
    const double* dptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [dptr, num_col, num_row](int row_idx) {
        CHECK(row_idx < num_row);
        std::vector<std::pair<int, double>> ret;
        const double* row_ptr = dptr + static_cast<size_t>(num_col) * row_idx;
        for (int i = 0; i < num_col; ++i) {
          ret.emplace_back(i, row_ptr[i]);
        }
        return ret;
      };
    } else {
      return [dptr, num_col, num_row](int row_idx) {
        CHECK(row_idx < num_row);
        std::vector<std::pair<int, double>> ret;
        for (int i = 0; i < num_col; ++i) {
          ret.emplace_back(i, dptr[static_cast<size_t>(num_row) * i + row_idx]);
        }
        return ret;
      };
    }
  }
}
inline std::function<std::vector<std::pair<int, double>>(int idx)> inline std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const int32_t* indptr, const int32_t* indices, const void* data, int float_type, uint64_t nindptr, uint64_t nelem) { RowFunctionFromCSR(const int32_t* indptr, const int32_t* indices, const void* data, int float_type, uint64_t nindptr, uint64_t nelem) {
......
...@@ -125,19 +125,13 @@ void Application::LoadData() { ...@@ -125,19 +125,13 @@ void Application::LoadData() {
Predictor* predictor = nullptr; Predictor* predictor = nullptr;
// need to continue training // need to continue training
if (boosting_->NumberOfSubModels() > 0) { if (boosting_->NumberOfSubModels() > 0) {
predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index); predictor = new Predictor(boosting_, config_.io_config.is_raw_score, config_.predict_leaf_index);
if (config_.io_config.num_class == 1){ predict_fun =
predict_fun = [&predictor](const std::vector<std::pair<int, double>>& features) {
[&predictor](const std::vector<std::pair<int, double>>& features) { return predictor->PredictRawOneLine(features);
return predictor->PredictRawOneLine(features); };
};
} else {
predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) {
return predictor->PredictMulticlassOneLine(features);
};
}
} }
// sync up random seed for data partition // sync up random seed for data partition
if (config_.is_parallel_find_bin) { if (config_.is_parallel_find_bin) {
config_.io_config.data_random_seed = config_.io_config.data_random_seed =
...@@ -262,7 +256,7 @@ void Application::Train() { ...@@ -262,7 +256,7 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
boosting_->SetNumUsedModel(config_.io_config.num_model_predict); boosting_->SetNumUsedModel(config_.io_config.num_model_predict);
// create predictor // create predictor
Predictor predictor(boosting_, config_.io_config.is_sigmoid, Predictor predictor(boosting_, config_.io_config.is_raw_score,
config_.predict_leaf_index); config_.predict_leaf_index);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.io_config.output_result.c_str(), config_.io_config.has_header);
......
...@@ -25,14 +25,13 @@ public: ...@@ -25,14 +25,13 @@ public:
/*! /*!
* \brief Constructor * \brief Constructor
* \param boosting Input boosting model * \param boosting Input boosting model
* \param is_sigmoid True if need to predict result with sigmoid transform (if needed, like binary classification) * \param is_raw_score True if need to predict result with raw score
* \param predict_leaf_index True if output leaf index instead of prediction score * \param predict_leaf_index True if output leaf index instead of prediction score
*/ */
Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index) Predictor(const Boosting* boosting, bool is_raw_score, bool is_predict_leaf_index)
: is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index) { : is_raw_score_(is_raw_score), is_predict_leaf_index_(is_predict_leaf_index) {
boosting_ = boosting; boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1; num_features_ = boosting_->MaxFeatureIdx() + 1;
num_class_ = boosting_->NumberOfClass();
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
{ {
...@@ -63,7 +62,7 @@ public: ...@@ -63,7 +62,7 @@ public:
std::vector<double> PredictRawOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictRawOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result without sigmoid transformation // get result without sigmoid transformation
return std::vector<double>(1, boosting_->PredictRaw(features_[tid])); return boosting_->PredictRaw(features_[tid]);
} }
/*! /*!
...@@ -85,18 +84,7 @@ public: ...@@ -85,18 +84,7 @@ public:
std::vector<double> PredictOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed // get result with sigmoid transform if needed
return std::vector<double>(1, boosting_->Predict(features_[tid])); return boosting_->Predict(features_[tid]);
}
/*!
* \brief prediction for multiclass classification
* \param features Feature of this record
* \return Prediction result
*/
std::vector<double> PredictMulticlassOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed
return boosting_->PredictMulticlass(features_[tid]);
} }
/*! /*!
...@@ -132,42 +120,20 @@ public: ...@@ -132,42 +120,20 @@ public:
}; };
std::function<std::string(const std::vector<std::pair<int, double>>&)> predict_fun; std::function<std::string(const std::vector<std::pair<int, double>>&)> predict_fun;
if (num_class_ > 1) { if (is_predict_leaf_index_) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){
std::vector<double> prediction = PredictMulticlassOneLine(features);
Common::Softmax(&prediction);
std::stringstream result_stream_buf;
for (size_t i = 0; i < prediction.size(); ++i){
if (i > 0) {
result_stream_buf << '\t';
}
result_stream_buf << prediction[i];
}
return result_stream_buf.str();
};
}
else if (is_predict_leaf_index_) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, double>>& features){
std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features); return Common::Join<int>(PredictLeafIndexOneLine(features), '\t');
std::stringstream result_stream_buf;
for (size_t i = 0; i < predicted_leaf_index.size(); ++i){
if (i > 0) {
result_stream_buf << '\t';
}
result_stream_buf << predicted_leaf_index[i];
}
return result_stream_buf.str();
}; };
} }
else { else {
if (is_simgoid_) { if (is_raw_score_) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, double>>& features){
return std::to_string(PredictOneLine(features)[0]); return Common::Join<double>(PredictRawOneLine(features), '\t');
}; };
} }
else { else {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, double>>& features){
return std::to_string(PredictRawOneLine(features)[0]); return Common::Join<double>(PredictOneLine(features), '\t');
}; };
} }
} }
...@@ -215,10 +181,8 @@ private: ...@@ -215,10 +181,8 @@ private:
double** features_; double** features_;
/*! \brief Number of features */ /*! \brief Number of features */
int num_features_; int num_features_;
/*! \brief Number of classes */
int num_class_;
/*! \brief True if need to predict result with sigmoid transform */ /*! \brief True if need to predict result with sigmoid transform */
bool is_simgoid_; bool is_raw_score_;
/*! \brief Number of threads */ /*! \brief Number of threads */
int num_threads_; int num_threads_;
/*! \brief True if output leaf index instead of prediction score */ /*! \brief True if output leaf index instead of prediction score */
......
...@@ -92,15 +92,22 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -92,15 +92,22 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
void GBDT::AddDataset(const Dataset* valid_data, void GBDT::AddDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) { const std::vector<const Metric*>& valid_metrics) {
if (iter_ > 0) {
Log::Fatal("Cannot add validation data after training started");
}
// for a validation dataset, we need its score and metric // for a validation dataset, we need its score and metric
valid_score_updater_.push_back(new ScoreUpdater(valid_data, num_class_)); valid_score_updater_.push_back(new ScoreUpdater(valid_data, num_class_));
valid_metrics_.emplace_back(); valid_metrics_.emplace_back();
best_iter_.emplace_back(); if (early_stopping_round_ > 0) {
best_score_.emplace_back(); best_iter_.emplace_back();
best_score_.emplace_back();
}
for (const auto& metric : valid_metrics) { for (const auto& metric : valid_metrics) {
valid_metrics_.back().push_back(metric); valid_metrics_.back().push_back(metric);
best_iter_.back().push_back(0); if (early_stopping_round_ > 0) {
best_score_.back().push_back(-1); best_iter_.back().push_back(0);
best_score_.back().push_back(kMinScore);
}
} }
} }
...@@ -231,7 +238,9 @@ bool GBDT::OutputMetric(int iter) { ...@@ -231,7 +238,9 @@ bool GBDT::OutputMetric(int iter) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName(); auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score()); auto scores = sub_metric->Eval(train_score_updater_->score());
Log::Info("Iteration: %d, %s: %s", iter, name, Common::ArrayToString<double>(scores, ' ').c_str()); for (size_t k = 0; k < name.size(); k++) {
Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), scores[k]);
}
} }
} }
// print validation metric // print validation metric
...@@ -241,17 +250,17 @@ bool GBDT::OutputMetric(int iter) { ...@@ -241,17 +250,17 @@ bool GBDT::OutputMetric(int iter) {
auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score()); auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
if ((iter % gbdt_config_->output_freq) == 0) { if ((iter % gbdt_config_->output_freq) == 0) {
auto name = valid_metrics_[i][j]->GetName(); auto name = valid_metrics_[i][j]->GetName();
Log::Info("Iteration: %d, %s: %s", iter, name, Common::ArrayToString<double>(test_scores, ' ').c_str()); for (size_t k = 0; k < name.size(); k++) {
Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), test_scores[k]);
}
} }
if (!ret && early_stopping_round_ > 0) { if (!ret && early_stopping_round_ > 0) {
bool the_bigger_the_better = valid_metrics_[i][j]->is_bigger_better(); auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
if (best_score_[i][j] < 0 if (cur_score > best_score_[i][j]) {
|| (!the_bigger_the_better && test_scores.back() < best_score_[i][j]) best_score_[i][j] = cur_score;
|| (the_bigger_the_better && test_scores.back() > best_score_[i][j])) {
best_score_[i][j] = test_scores.back();
best_iter_[i][j] = iter; best_iter_[i][j] = iter;
} else { } else {
if (iter - best_iter_[i][j] >= early_stopping_round_) ret = true; if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = true; }
} }
} }
} }
...@@ -261,40 +270,40 @@ bool GBDT::OutputMetric(int iter) { ...@@ -261,40 +270,40 @@ bool GBDT::OutputMetric(int iter) {
} }
/*! \brief Get eval result */ /*! \brief Get eval result */
std::vector<std::string> GBDT::EvalCurrent(bool is_eval_train) const { std::vector<double> GBDT::GetEvalAt(int data_idx) const {
std::vector<std::string> ret; CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size()));
if (is_eval_train) { std::vector<double> ret;
if (data_idx == 0) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score()); auto scores = sub_metric->Eval(train_score_updater_->score());
std::stringstream str_buf; for (auto score : scores) {
str_buf << name << " : " << Common::ArrayToString<double>(scores, ' '); ret.push_back(score);
ret.emplace_back(str_buf.str()); }
} }
} }
else {
for (size_t i = 0; i < valid_metrics_.size(); ++i) { auto used_idx = data_idx - 1;
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) { for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
auto name = valid_metrics_[i][j]->GetName(); auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score());
auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score()); for (auto score : test_scores) {
std::stringstream str_buf; ret.push_back(score);
str_buf << name << " : " << Common::ArrayToString<double>(test_scores, ' '); }
ret.emplace_back(str_buf.str());
} }
} }
return ret; return ret;
} }
/*! \brief Get prediction result */ /*! \brief Get prediction result */
const std::vector<const score_t*> GBDT::PredictCurrent(bool is_predict_train) const { const score_t* GBDT::GetScoreAt(int data_idx, data_size_t* out_len) const {
std::vector<const score_t*> ret; CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
if (is_predict_train) { if (data_idx == 0) {
ret.push_back(train_score_updater_->score()); *out_len = train_score_updater_->num_data() * num_class_;
} return train_score_updater_->score();
for (size_t i = 0; i < valid_metrics_.size(); ++i) { } else {
ret.push_back(valid_score_updater_[i]->score()); auto used_idx = data_idx - 1;
*out_len = valid_score_updater_[used_idx]->num_data() * num_class_;
return valid_score_updater_[used_idx]->score();
} }
return ret;
} }
void GBDT::Boosting() { void GBDT::Boosting() {
...@@ -470,40 +479,38 @@ std::string GBDT::FeatureImportance() const { ...@@ -470,40 +479,38 @@ std::string GBDT::FeatureImportance() const {
return str_buf.str(); return str_buf.str();
} }
double GBDT::PredictRaw(const double* value) const { std::vector<double> GBDT::PredictRaw(const double* value) const {
double ret = 0.0f; std::vector<double> ret(num_class_, 0.0f);
for (int i = 0; i < num_used_model_; ++i) {
ret += models_[i]->Predict(value);
}
return ret;
}
double GBDT::Predict(const double* value) const {
double ret = 0.0f;
for (int i = 0; i < num_used_model_; ++i) { for (int i = 0; i < num_used_model_; ++i) {
ret += models_[i]->Predict(value); for (int j = 0; j < num_class_; ++j) {
} ret[j] += models_[i * num_class_ + j]->Predict(value);
// if need sigmoid transform }
if (sigmoid_ > 0) {
ret = 1.0f / (1.0f + std::exp(- 2.0f * sigmoid_ * ret));
} }
return ret; return ret;
} }
std::vector<double> GBDT::PredictMulticlass(const double* value) const { std::vector<double> GBDT::Predict(const double* value) const {
std::vector<double> ret(num_class_, 0.0f); std::vector<double> ret(num_class_, 0.0f);
for (int i = 0; i < num_used_model_; ++i) { for (int i = 0; i < num_used_model_; ++i) {
for (int j = 0; j < num_class_; ++j){ for (int j = 0; j < num_class_; ++j) {
ret[j] += models_[i * num_class_ + j] -> Predict(value); ret[j] += models_[i * num_class_ + j]->Predict(value);
} }
} }
// if need sigmoid transform
if (sigmoid_ > 0 && num_class_ == 1) {
ret[0] = 1.0f / (1.0f + std::exp(- 2.0f * sigmoid_ * ret[0]));
} else if (num_class_ > 1) {
Common::Softmax(&ret);
}
return ret; return ret;
} }
std::vector<int> GBDT::PredictLeafIndex(const double* value) const { std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
std::vector<int> ret; std::vector<int> ret;
for (int i = 0; i < num_used_model_; ++i) { for (int i = 0; i < num_used_model_; ++i) {
ret.push_back(models_[i]->PredictLeafIndex(value)); for (int j = 0; j < num_class_; ++j) {
ret.push_back(models_[i * num_class_ + j]->PredictLeafIndex(value));
}
} }
return ret; return ret;
} }
......
...@@ -46,32 +46,25 @@ public: ...@@ -46,32 +46,25 @@ public:
*/ */
bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override; bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;
/*! \brief Get eval result */
std::vector<std::string> EvalCurrent(bool is_eval_train) const override; std::vector<double> GetEvalAt(int data_idx) const override;
/*! \brief Get prediction result */ /*! \brief Get prediction result */
const std::vector<const score_t*> PredictCurrent(bool is_predict_train) const override; const score_t* GetScoreAt(int data_idx, data_size_t* out_len) const override;
/*! /*!
* \brief Predtion for one record without sigmoid transformation * \brief Predtion for one record without sigmoid transformation
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double PredictRaw(const double* feature_values) const override; std::vector<double> PredictRaw(const double* feature_values) const override;
/*! /*!
* \brief Predtion for one record with sigmoid transformation if enabled * \brief Predtion for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double Predict(const double* feature_values) const override; std::vector<double> Predict(const double* feature_values) const override;
/*!
* \brief Predtion for multiclass classification
* \param feature_values Feature value on this record
* \return Prediction result, num_class numbers per line
*/
std::vector<double> PredictMulticlass(const double* value) const override;
/*! /*!
* \brief Predtion for one record with leaf index * \brief Predtion for one record with leaf index
......
...@@ -65,8 +65,8 @@ public: ...@@ -65,8 +65,8 @@ public:
tree->AddPredictionToScore(data_, data_indices, data_cnt, score_ + curr_class * num_data_); tree->AddPredictionToScore(data_, data_indices, data_cnt, score_ + curr_class * num_data_);
} }
/*! \brief Pointer of score */ /*! \brief Pointer of score */
inline const score_t * score() { return score_; } inline const score_t* score() { return score_; }
inline const data_size_t num_data() { return num_data_; }
private: private:
/*! \brief Number of total data */ /*! \brief Number of total data */
data_size_t num_data_; data_size_t num_data_;
......
...@@ -88,7 +88,25 @@ public: ...@@ -88,7 +88,25 @@ public:
if (boosting_ != nullptr) { delete boosting_; } if (boosting_ != nullptr) { delete boosting_; }
if (objective_fun_ != nullptr) { delete objective_fun_; } if (objective_fun_ != nullptr) { delete objective_fun_; }
} }
bool TrainOneIter() {
return boosting_->TrainOneIter(nullptr, nullptr, false);
}
bool TrainOneIter(const float* gradients, const float* hessians) {
return boosting_->TrainOneIter(gradients, hessians, false);
}
void PrepareForPrediction(int num_used_model, int predict_type) {
boosting_->SetNumUsedModel(num_used_model);
}
const Boosting* GetBoosting() const { return boosting_; }
private: private:
std::function<std::vector<double>(const std::vector<std::pair<int, double>>&)> predict_fun;
Boosting* boosting_; Boosting* boosting_;
/*! \brief All configs */ /*! \brief All configs */
OverallConfig config_; OverallConfig config_;
...@@ -102,6 +120,7 @@ private: ...@@ -102,6 +120,7 @@ private:
std::vector<std::vector<Metric*>> valid_metrics_; std::vector<std::vector<Metric*>> valid_metrics_;
/*! \brief Training objective function */ /*! \brief Training objective function */
ObjectiveFunction* objective_fun_; ObjectiveFunction* objective_fun_;
}; };
} }
...@@ -170,7 +189,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -170,7 +189,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
} }
ret = loader.CostructFromSampleData(sample_values, nrow); ret = loader.CostructFromSampleData(sample_values, nrow);
} else { } else {
ret = new Dataset(nrow); ret = new Dataset(nrow, config.io_config.num_class);
reinterpret_cast<const Dataset*>(*reference)->CopyFeatureBinMapperTo(ret, config.io_config.is_enable_sparse); reinterpret_cast<const Dataset*>(*reference)->CopyFeatureBinMapperTo(ret, config.io_config.is_enable_sparse);
} }
...@@ -231,7 +250,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr, ...@@ -231,7 +250,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
CHECK(num_col >= sample_values.size()); CHECK(num_col >= sample_values.size());
ret = loader.CostructFromSampleData(sample_values, nrow); ret = loader.CostructFromSampleData(sample_values, nrow);
} else { } else {
ret = new Dataset(nrow); ret = new Dataset(nrow, config.io_config.num_class);
reinterpret_cast<const Dataset*>(*reference)->CopyFeatureBinMapperTo(ret, config.io_config.is_enable_sparse); reinterpret_cast<const Dataset*>(*reference)->CopyFeatureBinMapperTo(ret, config.io_config.is_enable_sparse);
} }
...@@ -278,7 +297,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr, ...@@ -278,7 +297,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
} }
ret = loader.CostructFromSampleData(sample_values, nrow); ret = loader.CostructFromSampleData(sample_values, nrow);
} else { } else {
ret = new Dataset(nrow); ret = new Dataset(nrow, config.io_config.num_class);
reinterpret_cast<const Dataset*>(*reference)->CopyFeatureBinMapperTo(ret, config.io_config.is_enable_sparse); reinterpret_cast<const Dataset*>(*reference)->CopyFeatureBinMapperTo(ret, config.io_config.is_enable_sparse);
} }
...@@ -341,3 +360,165 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, ...@@ -341,3 +360,165 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
*out = dataset->num_total_features(); *out = dataset->num_total_features();
return 0; return 0;
} }
// ---- start of booster
/*!
* \brief create a new boosting learner
* \param train_data training data set
* \param valid_datas validation data sets
* \param valid_names names of the validation data sets
* \param n_valid_datas number of validation data sets
* \param parameters format: 'key1=value1 key2=value2'
* \param out handle of created booster
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const DatesetHandle valid_datas[],
  const char* valid_names[],
  int n_valid_datas,
  const char* parameters,
  BoosterHandle* out) {
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
  std::vector<const Dataset*> p_valid_datas;
  std::vector<std::string> p_valid_names;
  // reserve up front so the loop below performs no reallocation
  p_valid_datas.reserve(n_valid_datas);
  p_valid_names.reserve(n_valid_datas);
  for (int i = 0; i < n_valid_datas; ++i) {
    p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
    p_valid_names.emplace_back(valid_names[i]);
  }
  *out = new Booster(p_train_data, p_valid_datas, p_valid_names, parameters);
  return 0;
}
/*!
* \brief restore a booster from an existing model file
* \param filename path of the model file
* \param out handle of the loaded booster
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterLoadFromModelfile(
  const char* filename,
  BoosterHandle* out) {
  Booster* ret = new Booster(filename);
  *out = ret;
  return 0;
}
/*!
* \brief free the memory owned by a booster handle
* \param handle booster created by LGBM_BoosterCreate or LGBM_BoosterLoadFromModelfile
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  delete reinterpret_cast<Booster*>(handle);
  return 0;
}
/*!
* \brief perform one boosting iteration with the built-in objective
* \param handle booster handle
* \param is_finished set to 1 when training cannot continue, 0 otherwise
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  // TrainOneIter reports whether boosting should stop
  *is_finished = ref_booster->TrainOneIter() ? 1 : 0;
  return 0;
}
/*!
* \brief perform one boosting iteration with user-supplied gradients
* \param handle booster handle
* \param grad first-order gradient statistics
* \param hess second-order gradient statistics
* \param is_finished set to 1 when training cannot continue, 0 otherwise
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  // custom objective: caller provides gradient/hessian for the current scores
  *is_finished = ref_booster->TrainOneIter(grad, hess) ? 1 : 0;
  return 0;
}
/*!
* \brief get evaluation metric results for one data set
* \param handle booster handle
* \param data index of the data set to evaluate (0 = training data)
* \param out_len number of metric values written
* \param out_results caller-allocated buffer receiving the metric values;
*        must be large enough to hold all results
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterEval(BoosterHandle handle,
  int data,
  uint64_t* out_len,
  double* out_results) {
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto boosting = ref_booster->GetBoosting();
  auto result_buf = boosting->GetEvalAt(data);
  *out_len = static_cast<uint64_t>(result_buf.size());
  // copy into the caller-owned buffer (idiomatic std::copy over a manual loop)
  std::copy(result_buf.begin(), result_buf.end(), out_results);
  return 0;
}
/*!
* \brief get a pointer to the internal score (prediction) buffer of one data set
* \param handle booster handle
* \param data index of the data set (0 = training data)
* \param out_len length of the returned score array
* \param out_result set to a pointer into booster-owned memory; do not free,
*        and the pointer is invalidated by further training
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
  int data,
  uint64_t* out_len,
  const float** out_result) {
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto boosting = ref_booster->GetBoosting();
  // GetScoreAt takes data_size_t* (see GBM interface); use data_size_t instead
  // of a plain int so this stays correct if data_size_t is ever widened
  data_size_t len = 0;
  *out_result = boosting->GetScoreAt(data, &len);
  *out_len = static_cast<uint64_t>(len);
  return 0;
}
/*!
* \brief make prediction for a new data set in CSR format
* \param handle handle
* \param indptr pointer to row headers
* \param indices findex
* \param data fvalue
* \param float_type element type of the data buffer (assumed float32 vs float64 selector -- TODO confirm against implementation)
* \param nindptr number of rows in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_col number of columns
* \param predict_type
*          0:raw score
*          1:with sigmoid transform(if needed)
*          2:leaf index
* \param n_used_trees number of used tree
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const int32_t* indptr,
  const int32_t* indices,
  const void* data,
  int float_type,
  uint64_t nindptr,
  uint64_t nelem,
  uint64_t num_col,
  int predict_type,
  uint64_t n_used_trees,
  double* out_result);
/*!
* \brief make prediction for a new data set in CSC format
* \param handle handle
* \param col_ptr pointer to col headers
* \param indices findex
* \param data fvalue
* \param float_type element type of the data buffer (assumed float32 vs float64 selector -- TODO confirm against implementation)
* \param nindptr number of cols in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param predict_type
*          0:raw score
*          1:with sigmoid transform(if needed)
*          2:leaf index
* \param n_used_trees number of used tree
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const int32_t* col_ptr,
  const int32_t* indices,
  const void* data,
  int float_type,
  uint64_t nindptr,
  uint64_t nelem,
  uint64_t num_row,
  int predict_type,
  uint64_t n_used_trees,
  double* out_result);
/*!
* \brief make prediction for a new data set in dense (row-major matrix) format
* \param handle handle
* \param data pointer to the data space
* \param float_type element type of the data buffer (assumed float32 vs float64 selector -- TODO confirm against implementation)
* \param nrow number of rows
* \param ncol number of columns
* \param predict_type
*          0:raw score
*          1:with sigmoid transform(if needed)
*          2:leaf index
* \param n_used_trees number of used tree
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int float_type,
  int32_t nrow,
  int32_t ncol,
  int predict_type,
  uint64_t n_used_trees,
  double* out_result);
/*!
* \brief save model into file
* \param handle handle
* \param num_used_model number of models (iterations) to save -- TODO confirm whether a non-positive value means "all"
* \param filename file name
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_used_model,
  const char* filename);
...@@ -198,7 +198,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -198,7 +198,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "use_two_round_loading", &use_two_round_loading); GetBool(params, "use_two_round_loading", &use_two_round_loading);
GetBool(params, "is_save_binary_file", &is_save_binary_file); GetBool(params, "is_save_binary_file", &is_save_binary_file);
GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file); GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
GetBool(params, "is_sigmoid", &is_sigmoid); GetBool(params, "is_raw_score", &is_raw_score);
GetString(params, "output_model", &output_model); GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model); GetString(params, "input_model", &input_model);
GetString(params, "output_result", &output_result); GetString(params, "output_result", &output_result);
......
...@@ -21,10 +21,11 @@ Dataset::Dataset() { ...@@ -21,10 +21,11 @@ Dataset::Dataset() {
is_loading_from_binfile_ = false; is_loading_from_binfile_ = false;
} }
Dataset::Dataset(data_size_t num_data) { Dataset::Dataset(data_size_t num_data, int num_class) {
num_class_ = 1; num_class_ = num_class;
num_data_ = num_data; num_data_ = num_data;
is_loading_from_binfile_ = false; is_loading_from_binfile_ = false;
metadata_.Init(num_data_, num_class_, -1, -1);
} }
Dataset::~Dataset() { Dataset::~Dataset() {
......
...@@ -437,6 +437,7 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>& ...@@ -437,6 +437,7 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>&
delete bin_mappers[i]; delete bin_mappers[i];
} }
} }
dataset->metadata_.Init(dataset->num_data_, dataset->num_class_, -1, -1);
return dataset; return dataset;
} }
......
...@@ -30,9 +30,11 @@ public: ...@@ -30,9 +30,11 @@ public:
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << test_name << "'s " << PointWiseLossCalculator::Name(); str_buf << test_name << "'s : " << PointWiseLossCalculator::Name();
name_ = str_buf.str(); name_.emplace_back(str_buf.str());
num_data_ = num_data; num_data_ = num_data;
// get label // get label
label_ = metadata.label(); label_ = metadata.label();
...@@ -50,12 +52,12 @@ public: ...@@ -50,12 +52,12 @@ public:
} }
} }
const char* GetName() const override { std::vector<std::string> GetName() const override {
return name_.c_str(); return name_;
} }
bool is_bigger_better() const override { score_t factor_to_bigger_better() const override {
return false; return -1.0f;
} }
std::vector<double> Eval(const score_t* score) const override { std::vector<double> Eval(const score_t* score) const override {
...@@ -91,7 +93,7 @@ private: ...@@ -91,7 +93,7 @@ private:
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of test set */ /*! \brief Name of test set */
std::string name_; std::vector<std::string> name_;
/*! \brief Sigmoid parameter */ /*! \brief Sigmoid parameter */
score_t sigmoid_; score_t sigmoid_;
}; };
...@@ -152,18 +154,18 @@ public: ...@@ -152,18 +154,18 @@ public:
virtual ~AUCMetric() { virtual ~AUCMetric() {
} }
const char* GetName() const override { std::vector<std::string> GetName() const override {
return name_.c_str(); return name_;
} }
bool is_bigger_better() const override { score_t factor_to_bigger_better() const override {
return true; return 1.0f;
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << test_name << "'s AUC"; str_buf << test_name << "'s : AUC";
name_ = str_buf.str(); name_.emplace_back(str_buf.str());
num_data_ = num_data; num_data_ = num_data;
// get label // get label
...@@ -250,7 +252,7 @@ private: ...@@ -250,7 +252,7 @@ private:
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of test set */ /*! \brief Name of test set */
std::string name_; std::vector<std::string> name_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -25,8 +25,8 @@ public: ...@@ -25,8 +25,8 @@ public:
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << test_name << "'s " << PointWiseLossCalculator::Name(); str_buf << test_name << " : " << PointWiseLossCalculator::Name();
name_ = str_buf.str(); name_.emplace_back(str_buf.str());
num_data_ = num_data; num_data_ = num_data;
// get label // get label
label_ = metadata.label(); label_ = metadata.label();
...@@ -42,12 +42,12 @@ public: ...@@ -42,12 +42,12 @@ public:
} }
} }
const char* GetName() const override { std::vector<std::string> GetName() const override {
return name_.c_str(); return name_;
} }
bool is_bigger_better() const override { score_t factor_to_bigger_better() const override {
return false; return -1.0f;
} }
std::vector<double> Eval(const score_t* score) const override { std::vector<double> Eval(const score_t* score) const override {
...@@ -91,7 +91,7 @@ private: ...@@ -91,7 +91,7 @@ private:
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of this test set */ /*! \brief Name of this test set */
std::string name_; std::vector<std::string> name_;
}; };
/*! \brief L2 loss for multiclass task */ /*! \brief L2 loss for multiclass task */
......
...@@ -33,12 +33,12 @@ public: ...@@ -33,12 +33,12 @@ public:
~NDCGMetric() { ~NDCGMetric() {
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf;
str_buf << test_name << "'s ";
for (auto k : eval_at_) { for (auto k : eval_at_) {
std::stringstream str_buf;
str_buf << test_name << "'s : ";
str_buf << "NDCG@" + std::to_string(k) + " "; str_buf << "NDCG@" + std::to_string(k) + " ";
name_.emplace_back(str_buf.str());
} }
name_ = str_buf.str();
num_data_ = num_data; num_data_ = num_data;
// get label // get label
label_ = metadata.label(); label_ = metadata.label();
...@@ -76,12 +76,12 @@ public: ...@@ -76,12 +76,12 @@ public:
} }
} }
const char* GetName() const override { std::vector<std::string> GetName() const override {
return name_.c_str(); return name_;
} }
bool is_bigger_better() const override { score_t factor_to_bigger_better() const override {
return true; return 1.0f;
} }
std::vector<double> Eval(const score_t* score) const override { std::vector<double> Eval(const score_t* score) const override {
...@@ -149,7 +149,7 @@ private: ...@@ -149,7 +149,7 @@ private:
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const float* label_;
/*! \brief Name of test set */ /*! \brief Name of test set */
std::string name_; std::vector<std::string> name_;
/*! \brief Query boundaries information */ /*! \brief Query boundaries information */
const data_size_t* query_boundaries_; const data_size_t* query_boundaries_;
/*! \brief Number of queries */ /*! \brief Number of queries */
......
...@@ -23,18 +23,18 @@ public: ...@@ -23,18 +23,18 @@ public:
} }
const char* GetName() const override { std::vector<std::string> GetName() const override {
return name_.c_str(); return name_;
} }
bool is_bigger_better() const override { score_t factor_to_bigger_better() const override {
return false; return -1.0f;
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << test_name << "'s " << PointWiseLossCalculator::Name(); str_buf << test_name << " : " << PointWiseLossCalculator::Name();
name_ = str_buf.str(); name_.emplace_back(str_buf.str());
num_data_ = num_data; num_data_ = num_data;
// get label // get label
...@@ -85,7 +85,7 @@ private: ...@@ -85,7 +85,7 @@ private:
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of this test set */ /*! \brief Name of this test set */
std::string name_; std::vector<std::string> name_;
}; };
/*! \brief L2 loss for regression task */ /*! \brief L2 loss for regression task */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment