Commit 9fe0dea3 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

add predict_leaf_index option (#30)

* a python regression example

* add early-stopping

* fix bugs

* remove not needed files

* update early-stopping; fix warnings; add alias

* add comments; remove useless line

* change code order

* change Print to PrintAndGetLoss: return loss

* change comment; add bracket

* change return type of PredictLeafIndex in boosting.h

* move string cast into lambda function

* move if-then outside predict_fun
parent 70f7d605
...@@ -59,7 +59,14 @@ public: ...@@ -59,7 +59,14 @@ public:
* \return Prediction result for this record * \return Prediction result for this record
*/ */
virtual double Predict(const double * feature_values) const = 0; virtual double Predict(const double * feature_values) const = 0;
/*!
* \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record
* \return Predicted leaf index for this record
*/
virtual std::vector<int> PredictLeafIndex(const double * feature_values) const = 0;
/*! /*!
* \brief Serialize models by string * \brief Serialize models by string
* \return String output of trained model * \return String output of trained model
......
...@@ -195,6 +195,7 @@ public: ...@@ -195,6 +195,7 @@ public:
int num_threads = 0; int num_threads = 0;
bool is_parallel = false; bool is_parallel = false;
bool is_parallel_find_bin = false; bool is_parallel_find_bin = false;
bool predict_leaf_index = false;
IOConfig io_config; IOConfig io_config;
BoostingType boosting_type = BoostingType::kGBDT; BoostingType boosting_type = BoostingType::kGBDT;
BoostingConfig* boosting_config; BoostingConfig* boosting_config;
......
...@@ -75,6 +75,7 @@ public: ...@@ -75,6 +75,7 @@ public:
* \return Prediction result * \return Prediction result
*/ */
inline score_t Predict(const double* feature_values) const; inline score_t Predict(const double* feature_values) const;
inline int PredictLeafIndex(const double* feature_values) const;
/*! \brief Get Number of leaves*/ /*! \brief Get Number of leaves*/
inline int num_leaves() const { return num_leaves_; } inline int num_leaves() const { return num_leaves_; }
...@@ -141,11 +142,16 @@ private: ...@@ -141,11 +142,16 @@ private:
}; };
inline score_t Tree::Predict(const double* feature_values)const { inline score_t Tree::Predict(const double* feature_values) const {
int leaf = GetLeaf(feature_values); int leaf = GetLeaf(feature_values);
return LeafOutput(leaf); return LeafOutput(leaf);
} }
inline int Tree::PredictLeafIndex(const double* feature_values) const {
  // Route the record down the tree and report which leaf it lands in.
  return GetLeaf(feature_values);
}
inline int Tree::GetLeaf(const std::vector<BinIterator*>& iterators, inline int Tree::GetLeaf(const std::vector<BinIterator*>& iterators,
data_size_t data_idx) const { data_size_t data_idx) const {
int node = 0; int node = 0;
......
...@@ -63,7 +63,7 @@ public: ...@@ -63,7 +63,7 @@ public:
static void Fatal(const char *format, ...) { static void Fatal(const char *format, ...) {
va_list val; va_list val;
va_start(val, format); va_start(val, format);
fprintf(stderr, "[LightGBM] [Fatel] "); fprintf(stderr, "[LightGBM] [Fatal] ");
vfprintf(stderr, format, val); vfprintf(stderr, format, val);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fflush(stderr); fflush(stderr);
......
...@@ -125,7 +125,7 @@ void Application::LoadData() { ...@@ -125,7 +125,7 @@ void Application::LoadData() {
if (config_.io_config.input_model.size() > 0) { if (config_.io_config.input_model.size() > 0) {
LoadModel(); LoadModel();
if (boosting_->NumberOfSubModels() > 0) { if (boosting_->NumberOfSubModels() > 0) {
predictor = new Predictor(boosting_, config_.io_config.is_sigmoid); predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index);
predict_fun = predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) { [&predictor](const std::vector<std::pair<int, double>>& features) {
return predictor->PredictRawOneLine(features); return predictor->PredictRawOneLine(features);
...@@ -252,7 +252,7 @@ void Application::Train() { ...@@ -252,7 +252,7 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
// create predictor // create predictor
Predictor predictor(boosting_, config_.io_config.is_sigmoid); Predictor predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index);
predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str()); predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str());
Log::Info("Finish predict."); Log::Info("Finish predict.");
} }
......
...@@ -26,9 +26,10 @@ public: ...@@ -26,9 +26,10 @@ public:
* \brief Constructor * \brief Constructor
* \param boosting Input boosting model * \param boosting Input boosting model
* \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification) * \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification)
* \param predict_leaf_index True if output leaf index instead of prediction score
*/ */
Predictor(const Boosting* boosting, bool is_simgoid) Predictor(const Boosting* boosting, bool is_simgoid, bool predict_leaf_index)
: is_simgoid_(is_simgoid) { : is_simgoid_(is_simgoid), predict_leaf_index(predict_leaf_index) {
boosting_ = boosting; boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1; num_features_ = boosting_->MaxFeatureIdx() + 1;
#pragma omp parallel #pragma omp parallel
...@@ -71,6 +72,25 @@ public: ...@@ -71,6 +72,25 @@ public:
// get result without sigmoid transformation // get result without sigmoid transformation
return boosting_->PredictRaw(features_[tid]); return boosting_->PredictRaw(features_[tid]);
} }
/*!
* \brief Predict the leaf index of every tree for one record
* \param features Sparse (index, value) feature pairs for this record
* \return Predicted leaf index of each tree, one entry per tree
*/
std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, double>>& features) {
  const int tid = omp_get_thread_num();
  // reset this thread's dense feature buffer to all zeros
  std::memset(features_[tid], 0, sizeof(double) * num_features_);
  // scatter the sparse input values into the dense buffer,
  // ignoring indices beyond the model's known feature range
  for (const auto& feat : features) {
    if (feat.first < num_features_) {
      features_[tid][feat.first] = feat.second;
    }
  }
  // ask the boosting model for the leaf index in each tree
  return boosting_->PredictLeafIndex(features_[tid]);
}
/*! /*!
* \brief prediction for one record, will use sigmoid transformation if needed (only enabled for binary classification now)
...@@ -133,21 +153,37 @@ public: ...@@ -133,21 +153,37 @@ public:
}; };
Log::Info("Start prediction for data %s without label", data_filename); Log::Info("Start prediction for data %s without label", data_filename);
} }
std::function<double(const std::vector<std::pair<int, double>>&)> predict_fun; std::function<std::string(const std::vector<std::pair<int, double>>&)> predict_fun;
if (is_simgoid_) { if (predict_leaf_index) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features) { predict_fun = [this](const std::vector<std::pair<int, double>>& features){
return PredictOneLine(features); std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features);
}; std::stringstream result_ss;
} else { for (size_t i = 0; i < predicted_leaf_index.size(); ++i){
predict_fun = [this](const std::vector<std::pair<int, double>>& features) { if (i > 0) {
return PredictRawOneLine(features); result_ss << '\t';
}
result_ss << predicted_leaf_index[i];
}
return result_ss.str();
}; };
} }
else {
if (is_simgoid_) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){
return std::to_string(PredictOneLine(features));
};
}
else {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){
return std::to_string(PredictRawOneLine(features));
};
}
}
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
[this, &parser_fun, &predict_fun, &result_file] [this, &parser_fun, &predict_fun, &result_file]
(data_size_t, const std::vector<std::string>& lines) { (data_size_t, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, double>> oneline_features;
std::vector<double> pred_result(lines.size(), 0.0f); std::vector<std::string> pred_result(lines.size(), "");
#pragma omp parallel for schedule(static) private(oneline_features) #pragma omp parallel for schedule(static) private(oneline_features)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); i++) { for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); i++) {
oneline_features.clear(); oneline_features.clear();
...@@ -158,10 +194,9 @@ public: ...@@ -158,10 +194,9 @@ public:
} }
for (size_t i = 0; i < pred_result.size(); ++i) { for (size_t i = 0; i < pred_result.size(); ++i) {
fprintf(result_file, "%f\n", pred_result[i]); fprintf(result_file, "%s\n", pred_result[i].c_str());
} }
}; };
TextReader<data_size_t> predict_data_reader(data_filename); TextReader<data_size_t> predict_data_reader(data_filename);
predict_data_reader.ReadAllAndProcessParallel(process_fun); predict_data_reader.ReadAllAndProcessParallel(process_fun);
...@@ -180,6 +215,8 @@ private: ...@@ -180,6 +215,8 @@ private:
bool is_simgoid_; bool is_simgoid_;
/*! \brief Number of threads */ /*! \brief Number of threads */
int num_threads_; int num_threads_;
/*! \brief True if output leaf index instead of prediction score */
bool predict_leaf_index;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -369,4 +369,12 @@ double GBDT::Predict(const double* value) const { ...@@ -369,4 +369,12 @@ double GBDT::Predict(const double* value) const {
return ret; return ret;
} }
/*!
* \brief Predict the leaf index of each tree for one record
* \param value Dense feature values of this record
* \return Leaf index in each tree, one entry per tree in boosting order
*/
std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
  std::vector<int> ret;
  // exactly one leaf index per tree; reserve up front to avoid reallocations
  ret.reserve(models_.size());
  for (const auto& model : models_) {
    ret.push_back(model->PredictLeafIndex(value));
  }
  return ret;
}
} // namespace LightGBM } // namespace LightGBM
...@@ -59,6 +59,14 @@ public: ...@@ -59,6 +59,14 @@ public:
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double Predict(const double * feature_values) const override; double Predict(const double * feature_values) const override;
/*!
* \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record
* \return Predicted leaf index for this record
*/
std::vector<int> PredictLeafIndex(const double* value) const override;
/*! /*!
* \brief Serialize models by string * \brief Serialize models by string
* \return String output of trained model * \return String output of trained model
......
...@@ -14,6 +14,8 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -14,6 +14,8 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
// load main config types // load main config types
GetInt(params, "num_threads", &num_threads); GetInt(params, "num_threads", &num_threads);
GetTaskType(params); GetTaskType(params);
GetBool(params, "predict_leaf_index", &predict_leaf_index);
GetBoostingType(params); GetBoostingType(params);
GetObjectiveType(params); GetObjectiveType(params);
......
...@@ -94,7 +94,7 @@ void SerialTreeLearner::Init(const Dataset* train_data) { ...@@ -94,7 +94,7 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
break; break;
} }
} }
// initialize splits for leaf // initialize splits for leaf
smaller_leaf_splits_ = new LeafSplits(train_data_->num_features(), train_data_->num_data()); smaller_leaf_splits_ = new LeafSplits(train_data_->num_features(), train_data_->num_data());
larger_leaf_splits_ = new LeafSplits(train_data_->num_features(), train_data_->num_data()); larger_leaf_splits_ = new LeafSplits(train_data_->num_features(), train_data_->num_data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment