Commit e161a746 authored by Guolin Ke

Add Interface of Booster

parent 792826e1
@@ -50,6 +50,8 @@ public:
   virtual const score_t* GetTrainingScore(data_size_t* out_len) const = 0;
+
+  virtual void GetPredict(int data_idx, score_t* out_result, data_size_t* out_len) const = 0;
   /*!
   * \brief Prediction for one record, not sigmoid transform
   * \param feature_values Feature value on this record
@@ -75,7 +77,7 @@ public:
   /*!
   * \brief save model to file
   */
-  virtual void SaveModelToFile(bool is_finish, const char* filename) = 0;
+  virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) = 0;
   /*!
   * \brief Restore from a serialized string
...
@@ -262,7 +262,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
 DllExport int LGBM_BoosterEval(BoosterHandle handle,
                                int data,
                                uint64_t* out_len,
-                               double* out_results);
+                               float* out_results);
 /*!
 * \brief get raw score for training data, used to calculate gradients outside
@@ -287,7 +287,7 @@ this can be used to support customized eval function
 DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
                                      int data,
                                      uint64_t* out_len,
-                                     const float** out_result);
+                                     float* out_result);
 /*!
 * \brief make prediction for an new data set
@@ -319,36 +319,6 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                                         uint64_t n_used_trees,
                                         double* out_result);
-/*!
-* \brief make prediction for an new data set
-* \param handle handle
-* \param col_ptr pointer to col headers
-* \param indices findex
-* \param data fvalue
-* \param float_type 0:float_32 1:float64
-* \param ncol_ptr number of rows in the matix + 1
-* \param nelem number of nonzero elements in the matrix
-* \param num_row number of rows; when it's set to 0, then guess from data
-* \param predict_type
-*        0:raw score
-*        1:with transform(if needed)
-*        2:leaf index
-* \param n_used_trees number of used tree
-* \param out_result used to set a pointer to array
-* \return 0 when success, -1 when failure happens
-*/
-DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
-                                        const int32_t* col_ptr,
-                                        const int32_t* indices,
-                                        const void* data,
-                                        int float_type,
-                                        uint64_t ncol_ptr,
-                                        uint64_t nelem,
-                                        uint64_t num_row,
-                                        int predict_type,
-                                        uint64_t n_used_trees,
-                                        double* out_result);
 /*!
 * \brief make prediction for an new data set
 * \param handle handle
@@ -356,6 +326,7 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
 * \param float_type 0:float_32 1:float64
 * \param nrow number of rows
 * \param ncol number columns
+* \param is_row_major 1 for row major, 0 for column major
 * \param predict_type
 *        0:raw score
 *        1:with transform(if needed)
@@ -369,6 +340,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
                                         int float_type,
                                         int32_t nrow,
                                         int32_t ncol,
+                                        int is_row_major,
                                         int predict_type,
                                         uint64_t n_used_trees,
                                         double* out_result);
...
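For orientation, a minimal caller of the revised `LGBM_BoosterPredictForMat` entry point. This is a sketch, not part of the commit: `handle` and `num_class` are assumed to come from earlier `LGBM_Booster*` calls not shown here, and the buffer sizing (`nrow * num_class`, one score per class per row) follows the output loop in the implementation later in this diff.

```cpp
#include <cstdint>
#include <vector>

// Sketch only: assumes the declarations above are in scope and that `handle`
// refers to a trained booster obtained elsewhere.
void PredictDenseSketch(BoosterHandle handle, int num_class) {
  // 2 rows x 3 columns of float32 (float_type = 0), laid out row-major.
  std::vector<float> mat = {1.0f, 0.5f, 2.0f,
                            0.0f, 1.5f, 3.0f};
  const int32_t nrow = 2, ncol = 3;
  std::vector<double> out(static_cast<size_t>(nrow) * num_class);
  LGBM_BoosterPredictForMat(handle, mat.data(), /*float_type=*/0, nrow, ncol,
                            /*is_row_major=*/1, /*predict_type=*/1,
                            /*n_used_trees=*/100, out.data());
}
```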
@@ -4,6 +4,8 @@
 #include <LightGBM/utils/common.h>
 #include <LightGBM/utils/log.h>
+#include <LightGBM/meta.h>
+
 #include <vector>
 #include <string>
 #include <unordered_map>
@@ -94,7 +96,7 @@ public:
   std::string output_result = "LightGBM_predict_result.txt";
   std::string input_model = "";
   int verbosity = 1;
-  int num_model_predict = -1;
+  int num_model_predict = NO_LIMIT;
   bool is_pre_partition = false;
   bool is_enable_sparse = true;
   bool use_two_round_loading = false;
@@ -157,12 +159,12 @@ public:
   int feature_fraction_seed = 2;
   double feature_fraction = 1.0f;
   // max cache size(unit:MB) for historical histogram. < 0 means not limit
-  double histogram_pool_size = -1.0f;
+  double histogram_pool_size = NO_LIMIT;
   // max depth of tree model.
   // Still grow tree by leaf-wise, but limit the max depth to avoid over-fitting
   // And the max leaves will be min(num_leaves, pow(2, max_depth - 1))
   // max_depth < 0 means not limit
-  int max_depth = -1;
+  int max_depth = NO_LIMIT;
   void Set(const std::unordered_map<std::string, std::string>& params) override;
 };
...
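The `max_depth` comment above pins down how the two limits interact. A worked instance with illustrative numbers: `num_leaves = 127` and `max_depth = 5` cap the tree at `min(127, 2^(5-1)) = 16` leaves, while `max_depth = NO_LIMIT` leaves `num_leaves` as the only constraint. A sketch of that rule (not code from this commit):

```cpp
#include <algorithm>
#include <cmath>

// Leaf cap implied by the config comment: min(num_leaves, pow(2, max_depth - 1)),
// with max_depth < 0 (NO_LIMIT) meaning no depth constraint.
int EffectiveMaxLeaves(int num_leaves, int max_depth) {
  if (max_depth < 0) return num_leaves;
  return std::min(num_leaves, static_cast<int>(std::pow(2, max_depth - 1)));
}
```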
@@ -234,9 +234,6 @@ public:
   static Parser* CreateParser(const char* filename, bool has_header, int num_features, int label_idx);
 };
-using PredictFunction =
-  std::function<std::vector<double>(const std::vector<std::pair<int, double>>&)>;
-
 /*! \brief The main class of data set,
 *  which are used to traning or validation
 */
...
@@ -59,9 +59,9 @@ private:
   /*! \brief index of label column */
   int label_idx_ = 0;
   /*! \brief index of weight column */
-  int weight_idx_ = -1;
+  int weight_idx_ = NO_SPECIFIC;
   /*! \brief index of group column */
-  int group_idx_ = -1;
+  int group_idx_ = NO_SPECIFIC;
   /*! \brief Mapper from real feature index to used index*/
   std::unordered_set<int> ignore_features_;
   /*! \brief store feature names */
...
@@ -25,6 +25,12 @@ std::vector<const T*> ConstPtrInVectorWarpper(std::vector<T*> input) {
 using ReduceFunction = std::function<void(const char*, char*, int)>;
+
+using PredictFunction =
+  std::function<std::vector<double>(const std::vector<std::pair<int, double>>&)>;
+
+#define NO_LIMIT (-1)
+#define NO_SPECIFIC (-1)
 }  // namespace LightGBM
 #endif   // LightGBM_META_H_
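The `PredictFunction` alias, moved here from dataset.h so that both the config headers and the predictor can see it, maps one sparse record, given as (feature index, value) pairs, to a vector of scores. A toy instance, purely illustrative; real instances are built inside `Predictor` below:

```cpp
#include <utility>
#include <vector>
// assumes <LightGBM/meta.h> from this commit is on the include path

void PredictFunctionShapeDemo() {
  // A constant single-class "model" with the PredictFunction shape.
  LightGBM::PredictFunction constant_model =
      [](const std::vector<std::pair<int, double>>&) {
        return std::vector<double>{0.5};
      };
  // One sparse record: feature 0 is 1.5, feature 7 is -2.0, others default.
  std::vector<double> score = constant_model({{0, 1.5}, {7, -2.0}});
  (void)score;
}
```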
@@ -125,11 +125,8 @@ void Application::LoadData() {
   Predictor* predictor = nullptr;
   // need to continue training
   if (boosting_->NumberOfSubModels() > 0) {
-    predictor = new Predictor(boosting_, config_.io_config.is_raw_score, config_.predict_leaf_index);
-    predict_fun =
-      [&predictor](const std::vector<std::pair<int, double>>& features) {
-      return predictor->PredictRawOneLine(features);
-    };
+    predictor = new Predictor(boosting_, true, false);
+    predict_fun = predictor->GetPredictFunction();
   }
   // sync up random seed for data partition
@@ -244,11 +241,11 @@ void Application::Train() {
     // output used time per iteration
     Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
       std::milli>(end_time - start_time) * 1e-3, iter + 1);
-    boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str());
+    boosting_->SaveModelToFile(NO_LIMIT, is_finished, config_.io_config.output_model.c_str());
   }
   is_finished = true;
   // save model to file
-  boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str());
+  boosting_->SaveModelToFile(NO_LIMIT, is_finished, config_.io_config.output_model.c_str());
   Log::Info("Finished training");
 }
...
@@ -41,6 +41,28 @@ public:
     for (int i = 0; i < num_threads_; ++i) {
       features_[i] = new double[num_features_];
     }
+    if (is_predict_leaf_index_) {
+      predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) {
+        const int tid = PutFeatureValuesToBuffer(features);
+        // get result for leaf index
+        auto result = boosting_->PredictLeafIndex(features_[tid]);
+        return std::vector<double>(result.begin(), result.end());
+      };
+    } else {
+      if (is_raw_score_) {
+        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) {
+          const int tid = PutFeatureValuesToBuffer(features);
+          // get result without sigmoid transformation
+          return boosting_->PredictRaw(features_[tid]);
+        };
+      } else {
+        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) {
+          const int tid = PutFeatureValuesToBuffer(features);
+          return boosting_->Predict(features_[tid]);
+        };
+      }
+    }
   }
   /*!
   * \brief Destructor
@@ -54,37 +76,8 @@ public:
     }
   }
-  /*!
-  * \brief prediction for one record, only raw result (without sigmoid transformation)
-  * \param features Feature for this record
-  * \return Prediction result
-  */
-  std::vector<double> PredictRawOneLine(const std::vector<std::pair<int, double>>& features) {
-    const int tid = PutFeatureValuesToBuffer(features);
-    // get result without sigmoid transformation
-    return boosting_->PredictRaw(features_[tid]);
-  }
-  /*!
-  * \brief prediction for one record, only raw result (without sigmoid transformation)
-  * \param features Feature for this record
-  * \return Predictied leaf index
-  */
-  std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, double>>& features) {
-    const int tid = PutFeatureValuesToBuffer(features);
-    // get result for leaf index
-    return boosting_->PredictLeafIndex(features_[tid]);
-  }
-  /*!
-  * \brief prediction for one record, will use sigmoid transformation if needed (only enabled for binary classification noe)
-  * \param features Feature of this record
-  * \return Prediction result
-  */
-  std::vector<double> PredictOneLine(const std::vector<std::pair<int, double>>& features) {
-    const int tid = PutFeatureValuesToBuffer(features);
-    // get result with sigmoid transform if needed
-    return boosting_->Predict(features_[tid]);
-  }
+  inline const PredictFunction& GetPredictFunction() {
+    return predict_fun_;
+  }
   /*!
@@ -119,26 +112,8 @@ public:
       parser->ParseOneLine(buffer, feature, &tmp_label);
     };
-    std::function<std::string(const std::vector<std::pair<int, double>>&)> predict_fun;
-    if (is_predict_leaf_index_) {
-      predict_fun = [this](const std::vector<std::pair<int, double>>& features){
-        return Common::Join<int>(PredictLeafIndexOneLine(features), '\t');
-      };
-    }
-    else {
-      if (is_raw_score_) {
-        predict_fun = [this](const std::vector<std::pair<int, double>>& features){
-          return Common::Join<double>(PredictRawOneLine(features), '\t');
-        };
-      }
-      else {
-        predict_fun = [this](const std::vector<std::pair<int, double>>& features){
-          return Common::Join<double>(PredictOneLine(features), '\t');
-        };
-      }
-    }
     std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
-      [this, &parser_fun, &predict_fun, &result_file]
+      [this, &parser_fun, &result_file]
       (data_size_t, const std::vector<std::string>& lines) {
       std::vector<std::pair<int, double>> oneline_features;
       std::vector<std::string> pred_result(lines.size(), "");
@@ -148,7 +123,7 @@ public:
         // parser
         parser_fun(lines[i].c_str(), &oneline_features);
         // predict
-        pred_result[i] = predict_fun(oneline_features);
+        pred_result[i] = Common::Join<double>(predict_fun_(oneline_features), '\t');
       }
       for (size_t i = 0; i < pred_result.size(); ++i) {
@@ -187,6 +162,8 @@ private:
   int num_threads_;
   /*! \brief True if output leaf index instead of prediction score */
   bool is_predict_leaf_index_;
+  /*! \brief function for prediction */
+  PredictFunction predict_fun_;
 };
 }  // namespace LightGBM
...
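With the three one-line predictors folded into the constructor, callers pick the output type once and then just invoke the stored function. A usage sketch, assuming `boosting` is a trained `Boosting*` as in `Application::LoadData` above:

```cpp
// Sketch only; argument order matches the call sites in this commit:
// Predictor(boosting, is_raw_score, is_predict_leaf_index).
void PredictorUsageSketch(LightGBM::Boosting* boosting) {
  // Raw scores (no sigmoid), no leaf-index output.
  LightGBM::Predictor predictor(boosting, true, false);
  const LightGBM::PredictFunction& fn = predictor.GetPredictFunction();
  // Score one sparse record given as (feature_index, value) pairs.
  std::vector<double> raw = fn({{0, 1.5}, {3, 2.0}});
  (void)raw;
}
```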
@@ -293,25 +293,68 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const {
   return ret;
 }
-/*! \brief Get prediction result */
+/*! \brief Get training scores result */
 const score_t* GBDT::GetTrainingScore(data_size_t* out_len) const {
   *out_len = train_score_updater_->num_data() * num_class_;
   return train_score_updater_->score();
 }
+
+void GBDT::GetPredict(int data_idx, score_t* out_result, data_size_t* out_len) const {
+  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size()));
+  std::vector<double> ret;
+  const score_t* raw_scores = nullptr;
+  data_size_t num_data = 0;
+  if (data_idx == 0) {
+    raw_scores = train_score_updater_->score();
+    num_data = train_score_updater_->num_data();
+  } else {
+    auto used_idx = data_idx - 1;
+    raw_scores = valid_score_updater_[used_idx]->score();
+    num_data = valid_score_updater_[used_idx]->num_data();
+  }
+  *out_len = num_data * num_class_;
+  if (num_class_ > 1) {
+#pragma omp parallel for schedule(guided)
+    for (data_size_t i = 0; i < num_data; ++i) {
+      std::vector<double> tmp_result;
+      for (int j = 0; j < num_class_; ++j) {
+        tmp_result.push_back(raw_scores[j * num_data + i]);
+      }
+      Common::Softmax(&tmp_result);
+      for (int j = 0; j < num_class_; ++j) {
+        out_result[j * num_data + i] = static_cast<score_t>(tmp_result[j]);
+      }
+    }
+  } else if (sigmoid_ > 0) {
+#pragma omp parallel for schedule(guided)
+    for (data_size_t i = 0; i < num_data; ++i) {
+      out_result[i] = static_cast<score_t>(1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * raw_scores[i])));
+    }
+  } else {
+#pragma omp parallel for schedule(guided)
+    for (data_size_t i = 0; i < num_data; ++i) {
+      out_result[i] = raw_scores[i];
+    }
+  }
+}
+
 void GBDT::Boosting() {
   if (object_function_ == nullptr) {
     Log::Fatal("No object function provided");
   }
   // objective function will calculate gradients and hessians
+  int num_score = 0;
   object_function_->
-    GetGradients(train_score_updater_->score(), gradients_, hessians_);
+    GetGradients(GetTrainingScore(&num_score), gradients_, hessians_);
 }
-void GBDT::SaveModelToFile(bool is_finish, const char* filename) {
+void GBDT::SaveModelToFile(int num_used_model, bool is_finish, const char* filename) {
   // first time to this function, open file
-  if (saved_model_size_ == -1) {
+  if (saved_model_size_ < 0) {
     model_output_file_.open(filename);
     // output model type
     model_output_file_ << "gbdt" << std::endl;
@@ -330,7 +373,12 @@ void GBDT::SaveModelToFile(bool is_finish, const char* filename) {
   if (!model_output_file_.is_open()) {
     return;
   }
-  int rest = static_cast<int>(models_.size()) - early_stopping_round_ * num_class_;
+  if (num_used_model == NO_LIMIT) {
+    num_used_model = static_cast<int>(models_.size());
+  } else {
+    num_used_model = num_used_model * num_class_;
+  }
+  int rest = num_used_model - early_stopping_round_ * num_class_;
   // output tree models
   for (int i = saved_model_size_; i < rest; ++i) {
     model_output_file_ << "Tree=" << i << std::endl;
@@ -342,7 +390,7 @@ void GBDT::SaveModelToFile(bool is_finish, const char* filename) {
   model_output_file_.flush();
   // training finished, can close file
   if (is_finish) {
-    for (int i = saved_model_size_; i < static_cast<int>(models_.size()); ++i) {
+    for (int i = saved_model_size_; i < num_used_model; ++i) {
       model_output_file_ << "Tree=" << i << std::endl;
       model_output_file_ << models_[i]->ToString() << std::endl;
     }
...
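The new `GBDT::GetPredict` applies the transform the objective implies: softmax across the per-class score planes for multiclass, and a scaled sigmoid for binary classification. Worked through once for illustration: with `sigmoid_ = 1.0` and a raw score of 0.5, p = 1 / (1 + exp(-2 * 1.0 * 0.5)) ≈ 0.731. A sketch of that transform, not code from this diff:

```cpp
#include <cmath>

// The binary-classification transform used in GBDT::GetPredict above.
double SigmoidTransform(double raw_score, double sigmoid) {
  return 1.0 / (1.0 + std::exp(-2.0 * sigmoid * raw_score));
}
```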
@@ -52,6 +52,8 @@ public:
   /*! \brief Get prediction result */
   const score_t* GetTrainingScore(data_size_t* out_len) const override;
+
+  void GetPredict(int data_idx, score_t* out_result, data_size_t* out_len) const override;
   /*!
   * \brief Predtion for one record without sigmoid transformation
   * \param feature_values Feature value on this record
@@ -77,7 +79,7 @@ public:
   * \brief Serialize models by string
   * \return String output of tranined model
   */
-  void SaveModelToFile(bool is_finish, const char* filename) override;
+  void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override;
   /*!
   * \brief Restore from a serialized string
   */
...
@@ -16,19 +16,21 @@
 #include <cstring>
 #include <memory>
+
+#include "./application/predictor.hpp"
 namespace LightGBM {

 class Booster {
 public:
   explicit Booster(const char* filename):
-    boosting_(Boosting::CreateBoosting(filename)) {
+    boosting_(Boosting::CreateBoosting(filename)), predictor_(nullptr) {
   }
   Booster(const Dataset* train_data,
     std::vector<const Dataset*> valid_data,
     std::vector<std::string> valid_names,
     const char* parameters)
-    :train_data_(train_data), valid_datas_(valid_data) {
+    :train_data_(train_data), valid_datas_(valid_data), predictor_(nullptr) {
     config_.LoadFromString(parameters);
     // create boosting
     if (config_.io_config.input_model.size() > 0) {
@@ -87,6 +89,7 @@ public:
     valid_metrics_.clear();
     if (boosting_ != nullptr) { delete boosting_; }
     if (objective_fun_ != nullptr) { delete objective_fun_; }
+    if (predictor_ != nullptr) { delete predictor_; }
   }

   bool TrainOneIter() {
@@ -99,13 +102,31 @@ public:
   void PrepareForPrediction(int num_used_model, int predict_type) {
     boosting_->SetNumUsedModel(num_used_model);
+    if (predictor_ != nullptr) { delete predictor_; }
+    bool is_predict_leaf = false;
+    bool is_raw_score = false;
+    if (predict_type == 2) {
+      is_predict_leaf = true;
+    } else if (predict_type == 1) {
+      is_raw_score = false;
+    } else {
+      is_raw_score = true;
+    }
+    predictor_ = new Predictor(boosting_, is_raw_score, is_predict_leaf);
+  }
+
+  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
+    return predictor_->GetPredictFunction()(features);
   }
+
+  void SaveModelToFile(int num_used_model, const char* filename) {
+    boosting_->SaveModelToFile(num_used_model, true, filename);
+  }

   const Boosting* GetBoosting() const { return boosting_; }
+
+  const inline int NumberOfClass() const { return boosting_->NumberOfClass(); }

 private:
-  std::function<std::vector<double>(const std::vector<std::pair<int, double>>&)> predict_fun;
   Boosting* boosting_;
   /*! \brief All configs */
@@ -120,6 +141,8 @@ private:
   std::vector<std::vector<Metric*>> valid_metrics_;
   /*! \brief Training objective function */
   ObjectiveFunction* objective_fun_;
+  /*! \brief Using predictor for prediction task */
+  Predictor* predictor_;
 };
@@ -421,14 +444,14 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
 DllExport int LGBM_BoosterEval(BoosterHandle handle,
                                int data,
                                uint64_t* out_len,
-                               double* out_results) {
+                               float* out_results) {
   Booster* ref_booster = reinterpret_cast<Booster*>(handle);
   auto boosting = ref_booster->GetBoosting();
   auto result_buf = boosting->GetEvalAt(data);
   *out_len = static_cast<uint64_t>(result_buf.size());
   for (size_t i = 0; i < result_buf.size(); ++i) {
-    (out_results)[i] = result_buf[i];
+    (out_results)[i] = static_cast<float>(result_buf[i]);
   }
   return 0;
 }
@@ -446,6 +469,19 @@ DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
   return 0;
 }

+DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
+                                     int data,
+                                     uint64_t* out_len,
+                                     float* out_result) {
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  auto boosting = ref_booster->GetBoosting();
+  int len = 0;
+  boosting->GetPredict(data, out_result, &len);
+  *out_len = static_cast<uint64_t>(len);
+  return 0;
+}
+
 DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                                         const int32_t* indptr,
                                         const int32_t* indices,
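A caller's view of the new `LGBM_BoosterGetPredict` above, as a hedged sketch: `data == 0` selects the training set and `data == i + 1` the i-th validation set (per the `CHECK` in `GBDT::GetPredict`), and the output buffer must already hold `num_data * num_class` entries, which the caller is assumed to know from the `Dataset` it supplied; `score_t` is taken to be `float` here, matching the `float*` signature.

```cpp
#include <cstdint>
#include <vector>

// Sketch only: handle, num_data, and num_class come from elsewhere.
void FetchTrainingPredictions(BoosterHandle handle, int num_data, int num_class) {
  uint64_t out_len = 0;
  std::vector<float> preds(static_cast<size_t>(num_data) * num_class);
  LGBM_BoosterGetPredict(handle, /*data=*/0, &out_len, preds.data());
  // On return, out_len == num_data * num_class (see GBDT::GetPredict above).
}
```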
@@ -453,63 +489,53 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                                         int float_type,
                                         uint64_t nindptr,
                                         uint64_t nelem,
-                                        uint64_t num_col,
+                                        uint64_t,
                                         int predict_type,
                                         uint64_t n_used_trees,
-                                        double* out_result);
-
-/*!
-* \brief make prediction for an new data set
-* \param handle handle
-* \param col_ptr pointer to col headers
-* \param indices findex
-* \param data fvalue
-* \param nindptr number of rows in the matix + 1
-* \param nelem number of nonzero elements in the matrix
-* \param num_row number of rows; when it's set to 0, then guess from data
-* \param predict_type
-*        0:raw score
-*        1:with sigmoid transform(if needed)
-*        2:leaf index
-* \param n_used_trees number of used tree
-* \param out_result used to set a pointer to array
-* \return 0 when success, -1 when failure happens
-*/
-DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
-                                        const int32_t* col_ptr,
-                                        const int32_t* indices,
-                                        const void* data,
-                                        int float_type,
-                                        uint64_t nindptr,
-                                        uint64_t nelem,
-                                        uint64_t num_row,
-                                        int predict_type,
-                                        uint64_t n_used_trees,
-                                        double* out_result);
-
-/*!
-* \brief make prediction for an new data set
-* \param handle handle
-* \param data pointer to the data space
-* \param nrow number of rows
-* \param ncol number columns
-* \param missing which value to represent missing value
-* \param predict_type
-*        0:raw score
-*        1:with sigmoid transform(if needed)
-*        2:leaf index
-* \param n_used_trees number of used tree
-* \param out_result used to set a pointer to array
-* \return 0 when success, -1 when failure happens
-*/
+                                        double* out_result) {
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
+
+  auto get_row_fun = Common::RowFunctionFromCSR(indptr, indices, data, float_type, nindptr, nelem);
+  int num_class = ref_booster->NumberOfClass();
+  int nrow = static_cast<int>(nindptr - 1);
+#pragma omp parallel for schedule(guided)
+  for (int i = 0; i < nrow; ++i) {
+    auto one_row = get_row_fun(i);
+    auto predicton_result = ref_booster->Predict(one_row);
+    for (int j = 0; j < num_class; j++) {
+      out_result[i * num_class + j] = predicton_result[j];
+    }
+  }
+  return 0;
+}

 DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
                                         const void* data,
                                         int float_type,
                                         int32_t nrow,
                                         int32_t ncol,
+                                        int is_row_major,
                                         int predict_type,
                                         uint64_t n_used_trees,
-                                        double* out_result);
+                                        double* out_result) {
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
+
+  auto get_row_fun = Common::RowPairFunctionFromDenseMatric(data, nrow, ncol, float_type, is_row_major);
+  int num_class = ref_booster->NumberOfClass();
+#pragma omp parallel for schedule(guided)
+  for (int i = 0; i < nrow; ++i) {
+    auto one_row = get_row_fun(i);
+    auto predicton_result = ref_booster->Predict(one_row);
+    for (int j = 0; j < num_class; j++) {
+      out_result[i * num_class + j] = predicton_result[j];
+    }
+  }
+  return 0;
+}
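Both implementations lean on `Common::RowFunctionFromCSR` and `Common::RowPairFunctionFromDenseMatric`, whose definitions are outside this diff. A plausible reading of the CSR helper, under the assumption that it yields the (feature index, value) pairs `Booster::Predict` consumes; the real helper also dispatches on `float_type`, which this sketch fixes to float32:

```cpp
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Assumed shape of the CSR row extractor (not part of this commit):
// return row i of the CSR matrix as sparse (feature_index, value) pairs.
std::function<std::vector<std::pair<int, double>>(int)>
RowFunctionFromCSRSketch(const int32_t* indptr, const int32_t* indices,
                         const float* data) {
  return [=](int i) {
    std::vector<std::pair<int, double>> row;
    for (int32_t k = indptr[i]; k < indptr[i + 1]; ++k) {
      row.emplace_back(indices[k], static_cast<double>(data[k]));
    }
    return row;
  };
}
```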
 /*!
 * \brief save model into file
@@ -520,4 +546,9 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
 */
 DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
                                     int num_used_model,
-                                    const char* filename);
+                                    const char* filename) {
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  ref_booster->SaveModelToFile(num_used_model, filename);
+  return 0;
+}
@@ -437,7 +437,7 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>&
       delete bin_mappers[i];
     }
   }
-  dataset->metadata_.Init(dataset->num_data_, dataset->num_class_, -1, -1);
+  dataset->metadata_.Init(dataset->num_data_, dataset->num_class_, NO_SPECIFIC, NO_SPECIFIC);
   return dataset;
 }
...