Commit 422c0ef7 authored by Guolin Ke's avatar Guolin Ke
Browse files

almost finish, need some tests

parent fc383361
...@@ -73,7 +73,7 @@ public: ...@@ -73,7 +73,7 @@ public:
* \param result used to store prediction result, should allocate memory before call this function * \param result used to store prediction result, should allocate memory before call this function
* \param out_len length of returned score * \param out_len length of returned score
*/ */
virtual void GetPredictAt(int data_idx, score_t* result, data_size_t* out_len) const = 0; virtual void GetPredictAt(int data_idx, score_t* result, data_size_t* out_len) = 0;
/*! /*!
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
...@@ -127,7 +127,7 @@ public: ...@@ -127,7 +127,7 @@ public:
* \brief Get number of weak sub-models * \brief Get number of weak sub-models
* \return Number of weak sub-models * \return Number of weak sub-models
*/ */
virtual int NumberOfSubModels() const = 0; virtual int NumberOfTotalModel() const = 0;
/*! /*!
* \brief Get number of classes * \brief Get number of classes
...@@ -138,7 +138,7 @@ public: ...@@ -138,7 +138,7 @@ public:
/*! /*!
* \brief Set number of used model for prediction * \brief Set number of iterations for prediction
*/ */
virtual void SetNumUsedModel(int num_used_model) = 0; virtual void SetNumIterationForPred(int num_iteration) = 0;
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
......
...@@ -230,11 +230,13 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -230,11 +230,13 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
/*! /*!
* \brief load an existing boosting from model file * \brief load an existing boosting from model file
* \param filename filename of model * \param filename filename of model
* \param out_num_total_model number of total models
* \param out handle of created Booster * \param out handle of created Booster
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterCreateFromModelfile( DllExport int LGBM_BoosterCreateFromModelfile(
const char* filename, const char* filename,
int64_t* out_num_total_model,
BoosterHandle* out); BoosterHandle* out);
/*! /*!
...@@ -244,6 +246,12 @@ DllExport int LGBM_BoosterCreateFromModelfile( ...@@ -244,6 +246,12 @@ DllExport int LGBM_BoosterCreateFromModelfile(
*/ */
DllExport int LGBM_BoosterFree(BoosterHandle handle); DllExport int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief Get number of classes of a Booster
* \param handle handle of the Booster
* \param out_len used to store the number of classes
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len);
/*! /*!
* \brief update the model in one round * \brief update the model in one round
* \param handle handle * \param handle handle
...@@ -276,7 +284,7 @@ DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len); ...@@ -276,7 +284,7 @@ DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len);
* \brief Get names of eval metrics * \brief Get names of eval metrics
* \return total number of eval result * \return total number of eval result
*/ */
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs); DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs);
/*! /*!
* \brief get evaluation for training data and validation data * \brief get evaluation for training data and validation data
...@@ -291,17 +299,6 @@ DllExport int LGBM_BoosterGetEval(BoosterHandle handle, ...@@ -291,17 +299,6 @@ DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
float* out_results); float* out_results);
/*!
* \brief get raw score for training data, used to calculate gradients outside
* \param handle handle
* \param out_len len of output result
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle,
int64_t* out_len,
const float** out_result);
/*! /*!
* \brief Get prediction for training data and validation data * \brief Get prediction for training data and validation data
this can be used to support customized eval function this can be used to support customized eval function
...@@ -319,21 +316,21 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, ...@@ -319,21 +316,21 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
/*! /*!
* \brief make prediction for file * \brief make prediction for file
* \param handle handle * \param handle handle
* \param data_filename filename of data file
* \param data_has_header data file has header or not
* \param predict_type * \param predict_type
* 0:raw score * 0:raw score
* 1:with transform(if needed) * 1:with transform(if needed)
* 2:leaf index * 2:leaf index
* \param n_used_trees number of used tree * \param num_iteration number of iteration for prediction
* \param data_has_header data file has header or not
* \param data_filename filename of data file
* \param result_filename filename of result file * \param result_filename filename of result file
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
int predict_type,
int64_t n_used_trees,
int data_has_header,
const char* data_filename, const char* data_filename,
int data_has_header,
int predict_type,
int64_t num_iteration,
const char* result_filename); const char* result_filename);
/*! /*!
...@@ -351,7 +348,8 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -351,7 +348,8 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
* 0:raw score * 0:raw score
* 1:with transform(if needed) * 1:with transform(if needed)
* 2:leaf index * 2:leaf index
* \param n_used_trees number of used tree * \param num_iteration number of iteration for prediction
* \param out_len len of output result
* \param out_result used to set a pointer to array, should allocate memory before call this function * \param out_result used to set a pointer to array, should allocate memory before call this function
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
...@@ -365,8 +363,9 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -365,8 +363,9 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int64_t n_used_trees, int64_t num_iteration,
double* out_result); int64_t* out_len,
float* out_result);
/*! /*!
* \brief make prediction for a new data set * \brief make prediction for a new data set
...@@ -380,7 +379,8 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -380,7 +379,8 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* 0:raw score * 0:raw score
* 1:with transform(if needed) * 1:with transform(if needed)
* 2:leaf index * 2:leaf index
* \param n_used_trees number of used tree * \param num_iteration number of iteration for prediction
* \param out_len len of output result
* \param out_result used to set a pointer to array, should allocate memory before call this function * \param out_result used to set a pointer to array, should allocate memory before call this function
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
...@@ -391,18 +391,19 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -391,18 +391,19 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int64_t n_used_trees, int64_t num_iteration,
double* out_result); int64_t* out_len,
float* out_result);
/*! /*!
* \brief save model into file * \brief save model into file
* \param handle handle * \param handle handle
* \param num_used_model * \param num_iteration
* \param filename file name * \param filename file name
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_used_model, int num_iteration,
const char* filename); const char* filename);
......
...@@ -97,7 +97,7 @@ public: ...@@ -97,7 +97,7 @@ public:
std::string output_result = "LightGBM_predict_result.txt"; std::string output_result = "LightGBM_predict_result.txt";
std::string input_model = ""; std::string input_model = "";
int verbosity = 1; int verbosity = 1;
int num_model_predict = NO_LIMIT; int num_iteration_predict = NO_LIMIT;
bool is_pre_partition = false; bool is_pre_partition = false;
bool is_enable_sparse = true; bool is_enable_sparse = true;
bool use_two_round_loading = false; bool use_two_round_loading = false;
......
This diff is collapsed.
...@@ -108,7 +108,7 @@ void Application::LoadData() { ...@@ -108,7 +108,7 @@ void Application::LoadData() {
// prediction is needed if using input initial model(continued train) // prediction is needed if using input initial model(continued train)
PredictFunction predict_fun = nullptr; PredictFunction predict_fun = nullptr;
// need to continue training // need to continue training
if (boosting_->NumberOfSubModels() > 0) { if (boosting_->NumberOfTotalModel() > 0) {
Predictor predictor(boosting_.get(), true, false); Predictor predictor(boosting_.get(), true, false);
predict_fun = predictor.GetPredictFunction(); predict_fun = predictor.GetPredictFunction();
} }
...@@ -235,7 +235,7 @@ void Application::Train() { ...@@ -235,7 +235,7 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
boosting_->SetNumUsedModel(config_.io_config.num_model_predict); boosting_->SetNumIterationForPred(config_.io_config.num_iteration_predict);
// create predictor // create predictor
Predictor predictor(boosting_.get(), config_.io_config.is_predict_raw_score, Predictor predictor(boosting_.get(), config_.io_config.is_predict_raw_score,
config_.io_config.is_predict_leaf_index); config_.io_config.is_predict_leaf_index);
......
...@@ -43,6 +43,7 @@ public: ...@@ -43,6 +43,7 @@ public:
* \brief one training iteration * \brief one training iteration
*/ */
bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override { bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
is_update_score_cur_iter_ = false;
GBDT::TrainOneIter(gradient, hessian, false); GBDT::TrainOneIter(gradient, hessian, false);
// normalize // normalize
Normalize(); Normalize();
...@@ -58,20 +59,24 @@ public: ...@@ -58,20 +59,24 @@ public:
* \return training score * \return training score
*/ */
const score_t* GetTrainingScore(data_size_t* out_len) override { const score_t* GetTrainingScore(data_size_t* out_len) override {
if (!is_update_score_cur_iter_) {
// only drop one time in one iteration
DroppingTrees(); DroppingTrees();
is_update_score_cur_iter_ = true;
}
*out_len = train_score_updater_->num_data() * num_class_; *out_len = train_score_updater_->num_data() * num_class_;
return train_score_updater_->score(); return train_score_updater_->score();
} }
/*! /*!
* \brief save model to file * \brief save model to file
* \param num_used_model number of model that want to save, -1 means save all * \param num_iteration -1 means save all
* \param is_finish is training finished or not * \param is_finish is training finished or not
* \param filename filename that want to save to * \param filename filename that want to save to
*/ */
void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override { void SaveModelToFile(int num_iteration, bool is_finish, const char* filename) override {
// only save model once when is_finish = true // only save model once when is_finish = true
if (is_finish && saved_model_size_ < 0) { if (is_finish && saved_model_size_ < 0) {
GBDT::SaveModelToFile(num_used_model, is_finish, filename); GBDT::SaveModelToFile(num_iteration, is_finish, filename);
} }
} }
/*! /*!
...@@ -133,6 +138,8 @@ private: ...@@ -133,6 +138,8 @@ private:
double drop_rate_; double drop_rate_;
/*! \brief Random generator, used to select dropping trees */ /*! \brief Random generator, used to select dropping trees */
Random random_for_drop_; Random random_for_drop_;
/*! \brief Flag recording whether trees were already dropped (and the training score refreshed) in the current iteration */
bool is_update_score_cur_iter_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
namespace LightGBM { namespace LightGBM {
GBDT::GBDT() : saved_model_size_(-1), num_used_model_(0) { GBDT::GBDT() : saved_model_size_(-1), num_iteration_for_pred_(0) {
} }
...@@ -29,7 +29,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -29,7 +29,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
gbdt_config_ = config; gbdt_config_ = config;
iter_ = 0; iter_ = 0;
saved_model_size_ = -1; saved_model_size_ = -1;
num_used_model_ = 0; num_iteration_for_pred_ = 0;
max_feature_idx_ = 0; max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round; early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate; shrinkage_rate_ = gbdt_config_->learning_rate;
...@@ -296,24 +296,23 @@ const score_t* GBDT::GetTrainingScore(data_size_t* out_len) { ...@@ -296,24 +296,23 @@ const score_t* GBDT::GetTrainingScore(data_size_t* out_len) {
return train_score_updater_->score(); return train_score_updater_->score();
} }
void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) const { void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size())); CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size()));
std::vector<double> ret; std::vector<double> ret;
const score_t* raw_scores = nullptr; const score_t* raw_scores = nullptr;
data_size_t num_data = 0; data_size_t num_data = 0;
if (data_idx == 0) { if (data_idx == 0) {
raw_scores = train_score_updater_->score(); raw_scores = GetTrainingScore(out_len);
num_data = train_score_updater_->num_data(); num_data = train_score_updater_->num_data();
} else { } else {
auto used_idx = data_idx - 1; auto used_idx = data_idx - 1;
raw_scores = valid_score_updater_[used_idx]->score(); raw_scores = valid_score_updater_[used_idx]->score();
num_data = valid_score_updater_[used_idx]->num_data(); num_data = valid_score_updater_[used_idx]->num_data();
}
*out_len = num_data * num_class_; *out_len = num_data * num_class_;
}
if (num_class_ > 1) { if (num_class_ > 1) {
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
std::vector<double> tmp_result; std::vector<double> tmp_result;
for (int j = 0; j < num_class_; ++j) { for (int j = 0; j < num_class_; ++j) {
...@@ -325,12 +324,12 @@ void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) ...@@ -325,12 +324,12 @@ void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len)
} }
} }
} else if(sigmoid_ > 0.0f){ } else if(sigmoid_ > 0.0f){
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
out_result[i] = static_cast<score_t>(1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * raw_scores[i]))); out_result[i] = static_cast<score_t>(1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * raw_scores[i])));
} }
} else { } else {
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
out_result[i] = raw_scores[i]; out_result[i] = raw_scores[i];
} }
...@@ -348,7 +347,7 @@ void GBDT::Boosting() { ...@@ -348,7 +347,7 @@ void GBDT::Boosting() {
GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data()); GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
} }
void GBDT::SaveModelToFile(int num_used_model, bool is_finish, const char* filename) { void GBDT::SaveModelToFile(int num_iteration, bool is_finish, const char* filename) {
// first time to this function, open file // first time to this function, open file
if (saved_model_size_ < 0) { if (saved_model_size_ < 0) {
model_output_file_.open(filename); model_output_file_.open(filename);
...@@ -373,10 +372,11 @@ void GBDT::SaveModelToFile(int num_used_model, bool is_finish, const char* filen ...@@ -373,10 +372,11 @@ void GBDT::SaveModelToFile(int num_used_model, bool is_finish, const char* filen
if (!model_output_file_.is_open()) { if (!model_output_file_.is_open()) {
return; return;
} }
if (num_used_model == NO_LIMIT) { int num_used_model = 0;
if (num_iteration == NO_LIMIT) {
num_used_model = static_cast<int>(models_.size()); num_used_model = static_cast<int>(models_.size());
} else { } else {
num_used_model = num_used_model * num_class_; num_used_model = num_iteration * num_class_;
} }
int rest = num_used_model - early_stopping_round_ * num_class_; int rest = num_used_model - early_stopping_round_ * num_class_;
// output tree models // output tree models
...@@ -452,7 +452,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -452,7 +452,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
} }
} }
Log::Info("Finished loading %d models", models_.size()); Log::Info("Finished loading %d models", models_.size());
num_used_model_ = static_cast<int>(models_.size()) / num_class_; num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
} }
std::string GBDT::FeatureImportance() const { std::string GBDT::FeatureImportance() const {
...@@ -486,7 +486,7 @@ std::string GBDT::FeatureImportance() const { ...@@ -486,7 +486,7 @@ std::string GBDT::FeatureImportance() const {
std::vector<double> GBDT::PredictRaw(const double* value) const { std::vector<double> GBDT::PredictRaw(const double* value) const {
std::vector<double> ret(num_class_, 0.0f); std::vector<double> ret(num_class_, 0.0f);
for (int i = 0; i < num_used_model_; ++i) { for (int i = 0; i < num_iteration_for_pred_; ++i) {
for (int j = 0; j < num_class_; ++j) { for (int j = 0; j < num_class_; ++j) {
ret[j] += models_[i * num_class_ + j]->Predict(value); ret[j] += models_[i * num_class_ + j]->Predict(value);
} }
...@@ -496,7 +496,7 @@ std::vector<double> GBDT::PredictRaw(const double* value) const { ...@@ -496,7 +496,7 @@ std::vector<double> GBDT::PredictRaw(const double* value) const {
std::vector<double> GBDT::Predict(const double* value) const { std::vector<double> GBDT::Predict(const double* value) const {
std::vector<double> ret(num_class_, 0.0f); std::vector<double> ret(num_class_, 0.0f);
for (int i = 0; i < num_used_model_; ++i) { for (int i = 0; i < num_iteration_for_pred_; ++i) {
for (int j = 0; j < num_class_; ++j) { for (int j = 0; j < num_class_; ++j) {
ret[j] += models_[i * num_class_ + j]->Predict(value); ret[j] += models_[i * num_class_ + j]->Predict(value);
} }
...@@ -512,7 +512,7 @@ std::vector<double> GBDT::Predict(const double* value) const { ...@@ -512,7 +512,7 @@ std::vector<double> GBDT::Predict(const double* value) const {
std::vector<int> GBDT::PredictLeafIndex(const double* value) const { std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
std::vector<int> ret; std::vector<int> ret;
for (int i = 0; i < num_used_model_; ++i) { for (int i = 0; i < num_iteration_for_pred_; ++i) {
for (int j = 0; j < num_class_; ++j) { for (int j = 0; j < num_class_; ++j) {
ret.push_back(models_[i * num_class_ + j]->PredictLeafIndex(value)); ret.push_back(models_[i * num_class_ + j]->PredictLeafIndex(value));
} }
......
...@@ -73,7 +73,7 @@ public: ...@@ -73,7 +73,7 @@ public:
* \param result used to store prediction result, should allocate memory before call this function * \param result used to store prediction result, should allocate memory before call this function
* \param out_len length of returned score * \param out_len length of returned score
*/ */
void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) const override; void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) override;
/*! /*!
* \brief Prediction for one record without sigmoid transformation * \brief Prediction for one record without sigmoid transformation
...@@ -98,11 +98,11 @@ public: ...@@ -98,11 +98,11 @@ public:
/*! /*!
* \brief save model to file * \brief save model to file
* \param num_used_model number of model that want to save, -1 means save all * \param num_iteration -1 means save all
* \param is_finish is training finished or not * \param is_finish is training finished or not
* \param filename filename that want to save to * \param filename filename that want to save to
*/ */
virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override; virtual void SaveModelToFile(int num_iteration, bool is_finish, const char* filename) override;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
*/ */
...@@ -119,11 +119,12 @@ public: ...@@ -119,11 +119,12 @@ public:
*/ */
inline int LabelIdx() const override { return label_idx_; } inline int LabelIdx() const override { return label_idx_; }
/*! /*!
* \brief Get number of weak sub-models * \brief Get number of weak sub-models
* \return Number of weak sub-models * \return Number of weak sub-models
*/ */
inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); } inline int NumberOfTotalModel() const override { return static_cast<int>(models_.size()); }
/*! /*!
* \brief Get number of classes * \brief Get number of classes
...@@ -132,11 +133,13 @@ public: ...@@ -132,11 +133,13 @@ public:
inline int NumberOfClasses() const override { return num_class_; } inline int NumberOfClasses() const override { return num_class_; }
/*! /*!
* \brief Set number of used model for prediction * \brief Set number of iterations for prediction
*/ */
inline void SetNumUsedModel(int num_used_model) { inline void SetNumIterationForPred(int num_iteration) override {
if (num_used_model >= 0) { if (num_iteration > 0) {
num_used_model_ = static_cast<int>(num_used_model / num_class_); num_iteration_for_pred_ = num_iteration;
} else {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
} }
} }
...@@ -236,7 +239,7 @@ protected: ...@@ -236,7 +239,7 @@ protected:
/*! \brief File to write models */ /*! \brief File to write models */
std::ofstream model_output_file_; std::ofstream model_output_file_;
/*! \brief number of used model */ /*! \brief number of iterations used for prediction */
int num_used_model_; int num_iteration_for_pred_;
/*! \brief Shrinkage rate for one iteration */ /*! \brief Shrinkage rate for one iteration */
double shrinkage_rate_; double shrinkage_rate_;
}; };
......
...@@ -95,8 +95,8 @@ public: ...@@ -95,8 +95,8 @@ public:
return boosting_->TrainOneIter(gradients, hessians, false); return boosting_->TrainOneIter(gradients, hessians, false);
} }
void PrepareForPrediction(int num_used_model, int predict_type) { void PrepareForPrediction(int num_iteration, int predict_type) {
boosting_->SetNumUsedModel(num_used_model); boosting_->SetNumIterationForPred(num_iteration);
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) { if (predict_type == C_API_PREDICT_LEAF_INDEX) {
...@@ -109,6 +109,10 @@ public: ...@@ -109,6 +109,10 @@ public:
predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf)); predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
} }
void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
boosting_->GetPredictAt(data_idx, out_result, out_len);
}
std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) { std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
return predictor_->GetPredictFunction()(features); return predictor_->GetPredictFunction()(features);
} }
...@@ -117,8 +121,8 @@ public: ...@@ -117,8 +121,8 @@ public:
predictor_->Predict(data_filename, result_filename, data_has_header); predictor_->Predict(data_filename, result_filename, data_has_header);
} }
void SaveModelToFile(int num_used_model, const char* filename) { void SaveModelToFile(int num_iteration, const char* filename) {
boosting_->SaveModelToFile(num_used_model, true, filename); boosting_->SaveModelToFile(num_iteration, true, filename);
} }
int GetEvalCounts() const { int GetEvalCounts() const {
...@@ -129,11 +133,18 @@ public: ...@@ -129,11 +133,18 @@ public:
return ret; return ret;
} }
int GetEvalNames(const char*** out_strs) const { int GetEvalNames(char** out_strs) const {
int idx = 0; int idx = 0;
for (const auto& metric : train_metric_) { for (const auto& metric : train_metric_) {
for (const auto& name : metric->GetName()) { for (const auto& name : metric->GetName()) {
*(out_strs[idx++]) = name.c_str(); int j = 0;
auto name_cstr = name.c_str();
while (name_cstr[j] != '\0') {
out_strs[idx][j] = name_cstr[j];
++j;
}
out_strs[idx][j] = '\0';
++idx;
} }
} }
return idx; return idx;
...@@ -141,10 +152,6 @@ public: ...@@ -141,10 +152,6 @@ public:
const Boosting* GetBoosting() const { return boosting_.get(); } const Boosting* GetBoosting() const { return boosting_.get(); }
const float* GetTrainingScore(int* out_len) const { return boosting_->GetTrainingScore(out_len); }
const inline int NumberOfClasses() const { return boosting_->NumberOfClasses(); }
private: private:
std::unique_ptr<Boosting> boosting_; std::unique_ptr<Boosting> boosting_;
...@@ -449,9 +456,12 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -449,9 +456,12 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
DllExport int LGBM_BoosterCreateFromModelfile( DllExport int LGBM_BoosterCreateFromModelfile(
const char* filename, const char* filename,
int64_t* num_total_model,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN(); API_BEGIN();
*out = new Booster(filename); auto ret = std::unique_ptr<Booster>(new Booster(filename));
*num_total_model = static_cast<int64_t>(ret->GetBoosting()->NumberOfTotalModel());
*out = ret.release();
API_END(); API_END();
} }
...@@ -461,6 +471,13 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) { ...@@ -461,6 +471,13 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_END(); API_END();
} }
/*!
 * \brief Report the number of classes of the booster behind a handle
 * \param handle handle that is reinterpreted as a Booster pointer
 * \param out_len receives the number of classes
 * \return status value produced by the API_BEGIN / API_END macro pair
 */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto* boosting = booster->GetBoosting();
  *out_len = static_cast<int64_t>(boosting->NumberOfClasses());
  API_END();
}
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) { DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
...@@ -501,7 +518,7 @@ DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) ...@@ -501,7 +518,7 @@ DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len)
* \brief Get names of eval metrics * \brief Get names of eval metrics
* \return total number of eval result * \return total number of eval result
*/ */
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs) { DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalNames(out_strs); *out_len = ref_booster->GetEvalNames(out_strs);
...@@ -524,39 +541,27 @@ DllExport int LGBM_BoosterGetEval(BoosterHandle handle, ...@@ -524,39 +541,27 @@ DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
API_END(); API_END();
} }
DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle,
int64_t* out_len,
const float** out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
int len = 0;
*out_result = ref_booster->GetTrainingScore(&len);
*out_len = static_cast<int64_t>(len);
API_END();
}
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int data, int data,
int64_t* out_len, int64_t* out_len,
float* out_result) { float* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
int len = 0; int len = 0;
boosting->GetPredictAt(data, out_result, &len); ref_booster->GetPredictAt(data, out_result, &len);
*out_len = static_cast<int64_t>(len); *out_len = static_cast<int64_t>(len);
API_END(); API_END();
} }
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
int predict_type,
int64_t n_used_trees,
int data_has_header,
const char* data_filename, const char* data_filename,
int data_has_header,
int predict_type,
int64_t num_iteration,
const char* result_filename) { const char* result_filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type); ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
bool bool_data_has_header = data_has_header > 0 ? true : false; bool bool_data_has_header = data_has_header > 0 ? true : false;
ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header); ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header);
API_END(); API_END();
...@@ -572,23 +577,32 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -572,23 +577,32 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t, int64_t,
int predict_type, int predict_type,
int64_t n_used_trees, int64_t num_iteration,
double* out_result) { int64_t* out_len,
float* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type); ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int num_class = ref_booster->NumberOfClasses(); int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
if (num_iteration > 0) {
num_preb_in_one_row *= static_cast<int>(num_iteration);
} else {
num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
}
}
int nrow = static_cast<int>(nindptr - 1); int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < nrow; ++i) { for (int i = 0; i < nrow; ++i) {
auto one_row = get_row_fun(i); auto one_row = get_row_fun(i);
auto predicton_result = ref_booster->Predict(one_row); auto predicton_result = ref_booster->Predict(one_row);
for (int j = 0; j < num_class; ++j) { for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
out_result[i * num_class + j] = predicton_result[j]; out_result[i * num_preb_in_one_row + j] = static_cast<float>(predicton_result[j]);
} }
} }
*out_len = nrow * num_preb_in_one_row;
API_END(); API_END();
} }
...@@ -599,31 +613,40 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -599,31 +613,40 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int64_t n_used_trees, int64_t num_iteration,
double* out_result) { int64_t* out_len,
float* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type); ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
int num_class = ref_booster->NumberOfClasses(); int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
if (num_iteration > 0) {
num_preb_in_one_row *= static_cast<int>(num_iteration);
} else {
num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
}
}
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < nrow; ++i) { for (int i = 0; i < nrow; ++i) {
auto one_row = get_row_fun(i); auto one_row = get_row_fun(i);
auto predicton_result = ref_booster->Predict(one_row); auto predicton_result = ref_booster->Predict(one_row);
for (int j = 0; j < num_class; ++j) { for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
out_result[i * num_class + j] = predicton_result[j]; out_result[i * num_preb_in_one_row + j] = static_cast<float>(predicton_result[j]);
} }
} }
*out_len = nrow * num_preb_in_one_row;
API_END(); API_END();
} }
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_used_model, int num_iteration,
const char* filename) { const char* filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(num_used_model, filename); ref_booster->SaveModelToFile(num_iteration, filename);
API_END(); API_END();
} }
......
...@@ -183,7 +183,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -183,7 +183,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "data_random_seed", &data_random_seed); GetInt(params, "data_random_seed", &data_random_seed);
GetString(params, "data", &data_filename); GetString(params, "data", &data_filename);
GetInt(params, "verbose", &verbosity); GetInt(params, "verbose", &verbosity);
GetInt(params, "num_model_predict", &num_model_predict); GetInt(params, "num_iteration_predict", &num_iteration_predict);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt); GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
GetBool(params, "is_pre_partition", &is_pre_partition); GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse); GetBool(params, "is_enable_sparse", &is_enable_sparse);
......
...@@ -190,14 +190,16 @@ def test_booster(): ...@@ -190,14 +190,16 @@ def test_booster():
test_free_dataset(train) test_free_dataset(train)
test_free_dataset(test[0]) test_free_dataset(test[0])
booster2 = ctypes.c_void_p() booster2 = ctypes.c_void_p()
LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(booster2)) num_total_model = ctypes.c_long()
LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(num_total_model), ctypes.byref(booster2))
data = [] data = []
inp = open('../../examples/binary_classification/binary.test', 'r') inp = open('../../examples/binary_classification/binary.test', 'r')
for line in inp.readlines(): for line in inp.readlines():
data.append( [float(x) for x in line.split('\t')[1:]] ) data.append( [float(x) for x in line.split('\t')[1:]] )
inp.close() inp.close()
mat = np.array(data) mat = np.array(data)
preb = np.zeros(( mat.shape[0],1 ), dtype=np.float64) preb = np.zeros(mat.shape[0], dtype=np.float32)
num_preb = ctypes.c_long()
data = np.array(mat.reshape(mat.size), copy=False) data = np.array(mat.reshape(mat.size), copy=False)
LIB.LGBM_BoosterPredictForMat(booster2, LIB.LGBM_BoosterPredictForMat(booster2,
data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
...@@ -207,8 +209,9 @@ def test_booster(): ...@@ -207,8 +209,9 @@ def test_booster():
1, 1,
1, 1,
50, 50,
ctypes.byref(num_preb),
preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double))) preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
LIB.LGBM_BoosterPredictForFile(booster2, 1, 50, 0, c_str('../../examples/binary_classification/binary.test'), c_str('preb.txt')) LIB.LGBM_BoosterPredictForFile(booster2,c_str('../../examples/binary_classification/binary.test'),0 , 0, 50, c_str('preb.txt'))
LIB.LGBM_BoosterFree(booster2) LIB.LGBM_BoosterFree(booster2)
test_dataset() test_dataset()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment