Commit 993bbd5f authored by cbecker's avatar cbecker Committed by Guolin Ke
Browse files

Add prediction early stopping (#550)

* Add early stopping for prediction

* Fix GBDT if-else prediction with early stopping

* Small C++ embelishments to early stopping API and functions

* Fix early stopping efficiency issue by creating a singleton for no early stopping

* Python improvements to early stopping API

* Add assertion check for binary and multiclass prediction score length

* Update vcxproj and vcxproj.filters with new early stopping files

* Remove inline from PredictRaw(), the linker was not able to find it otherwise
parent 2cca8283
...@@ -13,6 +13,7 @@ namespace LightGBM { ...@@ -13,6 +13,7 @@ namespace LightGBM {
class Dataset; class Dataset;
class ObjectiveFunction; class ObjectiveFunction;
class Metric; class Metric;
class PredictionEarlyStopInstance;
/*! /*!
* \brief The interface for Boosting * \brief The interface for Boosting
...@@ -116,15 +117,19 @@ public: ...@@ -116,15 +117,19 @@ public:
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param output Prediction result for this record * \param output Prediction result for this record
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/ */
virtual void PredictRaw(const double* features, double* output) const = 0; virtual void PredictRaw(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0;
/*! /*!
* \brief Prediction for one record, sigmoid transformation will be used if needed * \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param output Prediction result for this record * \param output Prediction result for this record
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/ */
virtual void Predict(const double* features, double* output) const = 0; virtual void Predict(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0;
/*! /*!
* \brief Prediction for one record with leaf index * \brief Prediction for one record with leaf index
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
typedef void* DatasetHandle; typedef void* DatasetHandle;
typedef void* BoosterHandle; typedef void* BoosterHandle;
typedef void* PredictionEarlyStoppingHandle;
#define C_API_DTYPE_FLOAT32 (0) #define C_API_DTYPE_FLOAT32 (0)
#define C_API_DTYPE_FLOAT64 (1) #define C_API_DTYPE_FLOAT64 (1)
...@@ -521,6 +522,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -521,6 +522,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
const char* result_filename); const char* result_filename);
/*! /*!
...@@ -560,6 +562,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle, ...@@ -560,6 +562,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
* C_API_PREDICT_RAW_SCORE: raw score * C_API_PREDICT_RAW_SCORE: raw score
* C_API_PREDICT_LEAF_INDEX: leaf index * C_API_PREDICT_LEAF_INDEX: leaf index
* \param num_iteration number of iteration for prediction, <= 0 means no limit * \param num_iteration number of iteration for prediction, <= 0 means no limit
* \param early_stop_handle early stopping to use for prediction. If null, no early stopping is applied
* \param out_len len of output result * \param out_len len of output result
* \param out_result used to set a pointer to array, should allocate memory before call this function * \param out_result used to set a pointer to array, should allocate memory before call this function
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
...@@ -575,6 +578,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -575,6 +578,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -597,6 +601,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -597,6 +601,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* C_API_PREDICT_RAW_SCORE: raw score * C_API_PREDICT_RAW_SCORE: raw score
* C_API_PREDICT_LEAF_INDEX: leaf index * C_API_PREDICT_LEAF_INDEX: leaf index
* \param num_iteration number of iteration for prediction, <= 0 means no limit * \param num_iteration number of iteration for prediction, <= 0 means no limit
* \param early_stop_handle early stopping to use for prediction. If null, no early stopping is applied
* \param out_len len of output result * \param out_len len of output result
* \param out_result used to set a pointer to array, should allocate memory before call this function * \param out_result used to set a pointer to array, should allocate memory before call this function
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
...@@ -612,6 +617,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -612,6 +617,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -631,6 +637,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -631,6 +637,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
* C_API_PREDICT_RAW_SCORE: raw score * C_API_PREDICT_RAW_SCORE: raw score
* C_API_PREDICT_LEAF_INDEX: leaf index * C_API_PREDICT_LEAF_INDEX: leaf index
* \param num_iteration number of iteration for prediction, <= 0 means no limit * \param num_iteration number of iteration for prediction, <= 0 means no limit
* \param early_stop_handle early stopping to use for prediction. If null, no early stopping is applied
* \param out_len len of output result * \param out_len len of output result
* \param out_result used to set a pointer to array, should allocate memory before call this function * \param out_result used to set a pointer to array, should allocate memory before call this function
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
...@@ -643,6 +650,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -643,6 +650,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -713,6 +721,25 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle, ...@@ -713,6 +721,25 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int leaf_idx, int leaf_idx,
double val); double val);
/*!
* \brief create an new prediction early stopping instance that can be used to speed up prediction
* \param type early stopping type: "none", "multiclass" or "binary"
* \param round_period how often the classifier score is checked for the early stopping condition
* \param margin_threshold when the margin exceeds this value, early stopping kicks in and no more trees are evaluated
* \param out handle of created instance
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceCreate(const char* type,
int round_period,
double margin_threshold,
PredictionEarlyStoppingHandle* out);
/*!
\brief free prediction early stop instance
\return 0 when succeed
*/
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceFree(const PredictionEarlyStoppingHandle handle);
#if defined(_MSC_VER) #if defined(_MSC_VER)
// exception handle and error msg // exception handle and error msg
static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; } static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; }
......
#ifndef LIGHTGBM_PREDICTION_EARLY_STOP_H_
#define LIGHTGBM_PREDICTION_EARLY_STOP_H_
#include <functional>
#include <string>
#include <LightGBM/export.h>
namespace LightGBM
{
struct PredictionEarlyStopInstance
{
/// Callback function type for early stopping.
/// Takes current prediction and number of elements in prediction
/// @returns true if prediction should stop according to criterion
using FunctionType = std::function<bool(const double*, int)>;
FunctionType callbackFunction; // callback function itself
int roundPeriod; // call callbackFunction every `runPeriod` iterations
};
struct PredictionEarlyStopConfig
{
int roundPeriod;
double marginThreshold;
};
/// Create an early stopping algorithm of type `type`, with given roundPeriod and margin threshold
LIGHTGBM_EXPORT PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
const PredictionEarlyStopConfig& config);
} // namespace LightGBM
#endif // LIGHTGBM_PREDICTION_EARLY_STOP_H_
...@@ -6,7 +6,7 @@ Contributors: https://github.com/Microsoft/LightGBM/graphs/contributors ...@@ -6,7 +6,7 @@ Contributors: https://github.com/Microsoft/LightGBM/graphs/contributors
from __future__ import absolute_import from __future__ import absolute_import
from .basic import Booster, Dataset from .basic import Booster, Dataset, PredictionEarlyStopInstance
from .callback import (early_stopping, print_evaluation, record_evaluation, from .callback import (early_stopping, print_evaluation, record_evaluation,
reset_parameter) reset_parameter)
from .engine import cv, train from .engine import cv, train
...@@ -23,7 +23,7 @@ except ImportError: ...@@ -23,7 +23,7 @@ except ImportError:
__version__ = 0.2 __version__ = 0.2
__all__ = ['Dataset', 'Booster', __all__ = ['Dataset', 'Booster', 'PredictionEarlyStopInstance',
'train', 'cv', 'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
......
...@@ -296,7 +296,7 @@ class _InnerPredictor(object): ...@@ -296,7 +296,7 @@ class _InnerPredictor(object):
Only used for prediction, usually used for continued-train Only used for prediction, usually used for continued-train
Note: Can convert from Booster, but cannot convert to Booster Note: Can convert from Booster, but cannot convert to Booster
""" """
def __init__(self, model_file=None, booster_handle=None): def __init__(self, model_file=None, booster_handle=None, early_stop_instance=None):
"""Initialize the _InnerPredictor. Not expose to user """Initialize the _InnerPredictor. Not expose to user
Parameters Parameters
...@@ -305,6 +305,8 @@ class _InnerPredictor(object): ...@@ -305,6 +305,8 @@ class _InnerPredictor(object):
Path to the model file. Path to the model file.
booster_handle : Handle of Booster booster_handle : Handle of Booster
use handle to init use handle to init
early_stop_instance: object of type PredictionEarlyStopInstance
If None, no early stopping is applied
""" """
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
self.__is_manage_handle = True self.__is_manage_handle = True
...@@ -339,6 +341,11 @@ class _InnerPredictor(object): ...@@ -339,6 +341,11 @@ class _InnerPredictor(object):
else: else:
raise TypeError('Need Model file or Booster handle to create a predictor') raise TypeError('Need Model file or Booster handle to create a predictor')
if early_stop_instance is None:
self.early_stop_instance = PredictionEarlyStopInstance("none")
else:
self.early_stop_instance = early_stop_instance
def __del__(self): def __del__(self):
if self.__is_manage_handle: if self.__is_manage_handle:
_safe_call(_LIB.LGBM_BoosterFree(self.handle)) _safe_call(_LIB.LGBM_BoosterFree(self.handle))
...@@ -385,6 +392,7 @@ class _InnerPredictor(object): ...@@ -385,6 +392,7 @@ class _InnerPredictor(object):
int_data_has_header = 1 if data_has_header else 0 int_data_has_header = 1 if data_has_header else 0
if num_iteration > self.num_total_iteration: if num_iteration > self.num_total_iteration:
num_iteration = self.num_total_iteration num_iteration = self.num_total_iteration
if isinstance(data, string_type): if isinstance(data, string_type):
with _temp_file() as f: with _temp_file() as f:
_safe_call(_LIB.LGBM_BoosterPredictForFile( _safe_call(_LIB.LGBM_BoosterPredictForFile(
...@@ -393,6 +401,7 @@ class _InnerPredictor(object): ...@@ -393,6 +401,7 @@ class _InnerPredictor(object):
ctypes.c_int(int_data_has_header), ctypes.c_int(int_data_has_header),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
self.early_stop_instance.handle,
c_str(f.name))) c_str(f.name)))
lines = f.readlines() lines = f.readlines()
nrow = len(lines) nrow = len(lines)
...@@ -409,7 +418,7 @@ class _InnerPredictor(object): ...@@ -409,7 +418,7 @@ class _InnerPredictor(object):
predict_type) predict_type)
elif isinstance(data, DataFrame): elif isinstance(data, DataFrame):
preds, nrow = self.__pred_for_np2d(data.values, num_iteration, preds, nrow = self.__pred_for_np2d(data.values, num_iteration,
predict_type) predict_type, early_stop_instance_handle)
else: else:
try: try:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
...@@ -466,6 +475,7 @@ class _InnerPredictor(object): ...@@ -466,6 +475,7 @@ class _InnerPredictor(object):
ctypes.c_int(C_API_IS_ROW_MAJOR), ctypes.c_int(C_API_IS_ROW_MAJOR),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
self.early_stop_instance.handle,
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
...@@ -496,6 +506,7 @@ class _InnerPredictor(object): ...@@ -496,6 +506,7 @@ class _InnerPredictor(object):
ctypes.c_int64(csr.shape[1]), ctypes.c_int64(csr.shape[1]),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
self.early_stop_instance.handle,
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
...@@ -526,6 +537,7 @@ class _InnerPredictor(object): ...@@ -526,6 +537,7 @@ class _InnerPredictor(object):
ctypes.c_int64(csc.shape[0]), ctypes.c_int64(csc.shape[0]),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
self.early_stop_instance.handle,
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
...@@ -1568,7 +1580,8 @@ class Booster(object): ...@@ -1568,7 +1580,8 @@ class Booster(object):
ptr_string_buffer)) ptr_string_buffer))
return json.loads(string_buffer.value.decode()) return json.loads(string_buffer.value.decode())
def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True): def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True,
early_stop_instance=None):
""" """
Predict logic Predict logic
...@@ -1587,19 +1600,21 @@ class Booster(object): ...@@ -1587,19 +1600,21 @@ class Booster(object):
Used for txt data Used for txt data
is_reshape : bool is_reshape : bool
Reshape to (nrow, ncol) if true Reshape to (nrow, ncol) if true
early_stop_instance: object of type PredictionEarlyStopInstance.
If None, no early stopping is applied
Returns Returns
------- -------
Prediction result Prediction result
""" """
predictor = self._to_predictor() predictor = self._to_predictor(early_stop_instance)
if num_iteration <= 0: if num_iteration <= 0:
num_iteration = self.best_iteration num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape) return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape)
def _to_predictor(self): def _to_predictor(self, early_stop_instance=None):
"""Convert to predictor""" """Convert to predictor"""
predictor = _InnerPredictor(booster_handle=self.handle) predictor = _InnerPredictor(booster_handle=self.handle, early_stop_instance=early_stop_instance)
predictor.pandas_categorical = self.pandas_categorical predictor.pandas_categorical = self.pandas_categorical
return predictor return predictor
...@@ -1785,3 +1800,35 @@ class Booster(object): ...@@ -1785,3 +1800,35 @@ class Booster(object):
self.__attr[key] = value self.__attr[key] = value
else: else:
self.__attr.pop(key, None) self.__attr.pop(key, None)
class PredictionEarlyStopInstance(object):
""""PredictionEarlyStopInstance in LightGBM."""
def __init__(self, early_stop_type="none", round_period=20, margin_threshold=1.5):
"""
Create an early stopping object
Parameters
----------
early_stop_type: string
None, "none", "binary" or "multiclass". Regression is not supported.
round_period : int
The score will be checked every round_period to check if the early stopping criteria is met
margin_threshold : double
Early stopping will kick in when the margin is greater than margin_threshold
"""
self.handle = ctypes.c_void_p(0)
self.__attr = {}
if early_stop_type is None:
early_stop_type = "none"
_safe_call(_LIB.LGBM_PredictionEarlyStopInstanceCreate(
c_str(early_stop_type),
ctypes.c_int(round_period),
ctypes.c_double(margin_threshold),
ctypes.byref(self.handle)))
def __del__(self):
if self.handle is not None:
_safe_call(_LIB.LGBM_PredictionEarlyStopInstanceFree(self.handle))
...@@ -31,7 +31,8 @@ public: ...@@ -31,7 +31,8 @@ public:
* \param is_predict_leaf_index True if output leaf index instead of prediction score * \param is_predict_leaf_index True if output leaf index instead of prediction score
*/ */
Predictor(Boosting* boosting, int num_iteration, Predictor(Boosting* boosting, int num_iteration,
bool is_raw_score, bool is_predict_leaf_index) { bool is_raw_score, bool is_predict_leaf_index,
const PredictionEarlyStopInstance* earlyStop = nullptr) {
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
{ {
...@@ -54,17 +55,17 @@ public: ...@@ -54,17 +55,17 @@ public:
} else { } else {
if (is_raw_score) { if (is_raw_score) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) { predict_fun_ = [this, earlyStop](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features); CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->PredictRaw(predict_buf_[tid].data(), output); boosting_->PredictRaw(predict_buf_[tid].data(), output, earlyStop);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features); ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
}; };
} else { } else {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) { predict_fun_ = [this, earlyStop](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features); CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->Predict(predict_buf_[tid].data(), output); boosting_->Predict(predict_buf_[tid].data(), output, earlyStop);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features); ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
}; };
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <LightGBM/objective_function.h> #include <LightGBM/objective_function.h>
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>
#include <ctime> #include <ctime>
...@@ -703,10 +704,10 @@ std::string GBDT::ModelToIfElse(int num_iteration) const { ...@@ -703,10 +704,10 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << "#include \"gbdt.h\"" << std::endl; str_buf << "#include \"gbdt.h\"" << std::endl;
str_buf << "#include <LightGBM/utils/openmp_wrapper.h>" << std::endl;
str_buf << "#include <LightGBM/utils/common.h>" << std::endl; str_buf << "#include <LightGBM/utils/common.h>" << std::endl;
str_buf << "#include <LightGBM/objective_function.h>" << std::endl; str_buf << "#include <LightGBM/objective_function.h>" << std::endl;
str_buf << "#include <LightGBM/metric.h>" << std::endl; str_buf << "#include <LightGBM/metric.h>" << std::endl;
str_buf << "#include <LightGBM/prediction_early_stop.h>" << std::endl;
str_buf << "#include <ctime>" << std::endl; str_buf << "#include <ctime>" << std::endl;
str_buf << "#include <sstream>" << std::endl; str_buf << "#include <sstream>" << std::endl;
str_buf << "#include <chrono>" << std::endl; str_buf << "#include <chrono>" << std::endl;
...@@ -737,32 +738,32 @@ std::string GBDT::ModelToIfElse(int num_iteration) const { ...@@ -737,32 +738,32 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream pred_str_buf; std::stringstream pred_str_buf;
pred_str_buf << "\t" << "if (num_threads_ <= num_tree_per_iteration_) {" << std::endl; pred_str_buf << "\t" << "const auto noEarlyStop = createPredictionEarlyStopInstance(\"none\", PredictionEarlyStopConfig());" << std::endl;
pred_str_buf << "\t\t" << "#pragma omp parallel for schedule(static)" << std::endl; pred_str_buf << "\t" << "if (earlyStop == nullptr) {" << std::endl;
pred_str_buf << "\t\t" << "earlyStop = &noEarlyStop;" << std::endl;
pred_str_buf << "\t" << "}" << std::endl;
pred_str_buf << "\t" << "int earlyStopRoundCounter = 0;" << std::endl;
pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl; pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl; pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl; pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "} else {" << std::endl; pred_str_buf << "\t\t" << "++earlyStopRoundCounter;" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl; pred_str_buf << "\t\t" << "if (earlyStop->roundPeriod == earlyStopRoundCounter) {" << std::endl;
pred_str_buf << "\t\t\t" << "double t = 0.0f;" << std::endl; pred_str_buf << "\t\t\t" << "if (earlyStop->callbackFunction(output, num_tree_per_iteration_))" << std::endl;
pred_str_buf << "\t\t\t" << "#pragma omp parallel for schedule(static) reduction(+:t)" << std::endl; pred_str_buf << "\t\t\t\t" << "return;" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl; pred_str_buf << "\t\t\t" << "earlyStopRoundCounter = 0;" << std::endl;
pred_str_buf << "\t\t\t\t" << "t += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] = t;" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl; pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "}" << std::endl; pred_str_buf << "\t" << "}" << std::endl;
str_buf << "void GBDT::PredictRaw(const double* features, double *output) const {" << std::endl; str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* earlyStop) const {" << std::endl;
str_buf << pred_str_buf.str(); str_buf << pred_str_buf.str();
str_buf << "}" << std::endl; str_buf << "}" << std::endl;
str_buf << std::endl; str_buf << std::endl;
// Predict // Predict
str_buf << "void GBDT::Predict(const double* features, double *output) const {" << std::endl; str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* earlyStop) const {" << std::endl;
str_buf << pred_str_buf.str(); str_buf << "\t" << "PredictRaw(features, output, earlyStop);" << std::endl;
str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl; str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl; str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl; str_buf << "\t" << "}" << std::endl;
......
...@@ -136,9 +136,11 @@ public: ...@@ -136,9 +136,11 @@ public:
return num_preb_in_one_row; return num_preb_in_one_row;
} }
void PredictRaw(const double* features, double* output) const override; void PredictRaw(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const override;
void Predict(const double* features, double* output) const override; void Predict(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop) const override;
void PredictLeafIndex(const double* features, double* output) const override; void PredictLeafIndex(const double* features, double* output) const override;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <LightGBM/objective_function.h> #include <LightGBM/objective_function.h>
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>
#include <ctime> #include <ctime>
...@@ -15,46 +16,41 @@ ...@@ -15,46 +16,41 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
namespace
{
/// Singleton used when earlyStop is nullptr in PredictRaw()
const auto noEarlyStop = LightGBM::createPredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
}
namespace LightGBM { namespace LightGBM {
void GBDT::PredictRaw(const double* features, double* output) const { void GBDT::PredictRaw(const double* features, double* output, const PredictionEarlyStopInstance* earlyStop) const {
if (num_threads_ <= num_tree_per_iteration_) { if (earlyStop == nullptr)
#pragma omp parallel for schedule(static) {
for (int k = 0; k < num_tree_per_iteration_; ++k) { earlyStop = &noEarlyStop;
for (int i = 0; i < num_iteration_for_pred_; ++i) {
output[k] += models_[i * num_tree_per_iteration_ + k]->Predict(features);
}
}
} else {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
double t = 0.0f;
#pragma omp parallel for schedule(static) reduction(+:t)
for (int i = 0; i < num_iteration_for_pred_; ++i) {
t += models_[i * num_tree_per_iteration_ + k]->Predict(features);
}
output[k] = t;
}
} }
}
void GBDT::Predict(const double* features, double* output) const { int earlyStopRoundCounter = 0;
if (num_threads_ <= num_tree_per_iteration_) {
#pragma omp parallel for schedule(static)
for (int k = 0; k < num_tree_per_iteration_; ++k) {
for (int i = 0; i < num_iteration_for_pred_; ++i) { for (int i = 0; i < num_iteration_for_pred_; ++i) {
// predict all the trees for one iteration
for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] += models_[i * num_tree_per_iteration_ + k]->Predict(features); output[k] += models_[i * num_tree_per_iteration_ + k]->Predict(features);
} }
// check early stopping
++earlyStopRoundCounter;
if (earlyStop->roundPeriod == earlyStopRoundCounter) {
if (earlyStop->callbackFunction(output, num_tree_per_iteration_)) {
return;
} }
} else { earlyStopRoundCounter = 0;
for (int k = 0; k < num_tree_per_iteration_; ++k) {
double t = 0.0f;
#pragma omp parallel for schedule(static) reduction(+:t)
for (int i = 0; i < num_iteration_for_pred_; ++i) {
t += models_[i * num_tree_per_iteration_ + k]->Predict(features);
}
output[k] = t;
} }
} }
}
void GBDT::Predict(const double* features, double* output, const PredictionEarlyStopInstance* earlyStop) const {
PredictRaw(features, output, earlyStop);
if (objective_function_ != nullptr) { if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output); objective_function_->ConvertOutput(output, output);
} }
......
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/log.h>
using namespace LightGBM;
#include <algorithm>
#include <vector>
#include <cmath>
#include <limits>
namespace
{
PredictionEarlyStopInstance createNone(const PredictionEarlyStopConfig&)
{
return PredictionEarlyStopInstance{
[](const double*, int)
{
return false;
},
std::numeric_limits<int>::max() // make sure the lambda is almost never called
};
}
PredictionEarlyStopInstance createMulticlass(const PredictionEarlyStopConfig& config)
{
// marginThreshold will be captured by value
const double marginThreshold = config.marginThreshold;
return PredictionEarlyStopInstance{
[marginThreshold](const double* pred, int sz)
{
if(sz < 2) {
Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger");
}
// copy and sort
std::vector<double> votes(static_cast<size_t>(sz));
for (int i=0; i < sz; ++i) {
votes[i] = pred[i];
}
std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater<double>());
const auto margin = votes[0] - votes[1];
if (margin > marginThreshold) {
return true;
}
return false;
},
config.roundPeriod
};
}
PredictionEarlyStopInstance createBinary(const PredictionEarlyStopConfig& config)
{
// marginThreshold will be captured by value
const double marginThreshold = config.marginThreshold;
return PredictionEarlyStopInstance{
[marginThreshold](const double* pred, int sz)
{
if(sz != 1) {
Log::Fatal("Binary early stopping needs predictions to be of length one");
}
const auto margin = 2.0 * fabs(pred[0]);
if (margin > marginThreshold) {
return true;
}
return false;
},
config.roundPeriod
};
}
}
namespace LightGBM
{
PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
const PredictionEarlyStopConfig& config)
{
if (type == "none")
{
return createNone(config);
}
else if (type == "multiclass")
{
return createMulticlass(config);
}
else if (type == "binary")
{
return createBinary(config);
}
else
{
throw std::runtime_error("Unknown early stopping type: " + type);
}
}
}
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <LightGBM/objective_function.h> #include <LightGBM/objective_function.h>
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <LightGBM/config.h> #include <LightGBM/config.h>
#include <LightGBM/prediction_early_stop.h>
#include <cstdio> #include <cstdio>
#include <vector> #include <vector>
...@@ -162,6 +163,7 @@ public: ...@@ -162,6 +163,7 @@ public:
void Predict(int num_iteration, int predict_type, int nrow, void Predict(int num_iteration, int predict_type, int nrow,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const PredictionEarlyStoppingHandle early_stop_handle,
double* out_result, int64_t* out_len) { double* out_result, int64_t* out_len) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false; bool is_predict_leaf = false;
...@@ -173,7 +175,8 @@ public: ...@@ -173,7 +175,8 @@ public:
} else { } else {
is_raw_score = false; is_raw_score = false;
} }
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf); Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
reinterpret_cast<const PredictionEarlyStopInstance*>(early_stop_handle));
int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf); int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
auto pred_fun = predictor.GetPredictFunction(); auto pred_fun = predictor.GetPredictFunction();
auto pred_wrt_ptr = out_result; auto pred_wrt_ptr = out_result;
...@@ -186,7 +189,8 @@ public: ...@@ -186,7 +189,8 @@ public:
} }
void Predict(int num_iteration, int predict_type, const char* data_filename, void Predict(int num_iteration, int predict_type, const char* data_filename,
int data_has_header, const char* result_filename) { int data_has_header, const PredictionEarlyStoppingHandle early_stop_handle,
const char* result_filename) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
...@@ -197,7 +201,8 @@ public: ...@@ -197,7 +201,8 @@ public:
} else { } else {
is_raw_score = false; is_raw_score = false;
} }
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf); Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
reinterpret_cast<const PredictionEarlyStopInstance*>(early_stop_handle));
bool bool_data_has_header = data_has_header > 0 ? true : false; bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header); predictor.Predict(data_filename, result_filename, bool_data_has_header);
} }
...@@ -950,10 +955,12 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -950,10 +955,12 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
const char* result_filename) { const char* result_filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header, result_filename); ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
early_stop_handle, result_filename);
API_END(); API_END();
} }
...@@ -980,13 +987,15 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -980,13 +987,15 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t, int64_t,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int nrow = static_cast<int>(nindptr - 1); int nrow = static_cast<int>(nindptr - 1);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, out_result, out_len); ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
early_stop_handle, out_result, out_len);
API_END(); API_END();
} }
...@@ -1001,6 +1010,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1001,6 +1010,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
...@@ -1021,7 +1031,8 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1021,7 +1031,8 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
} }
return one_row; return one_row;
}; };
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, out_result, out_len); ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, early_stop_handle,
out_result, out_len);
API_END(); API_END();
} }
...@@ -1033,12 +1044,14 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1033,12 +1044,14 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, out_result, out_len); ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
early_stop_handle, out_result, out_len);
API_END(); API_END();
} }
...@@ -1101,6 +1114,31 @@ int LGBM_BoosterSetLeafValue(BoosterHandle handle, ...@@ -1101,6 +1114,31 @@ int LGBM_BoosterSetLeafValue(BoosterHandle handle,
API_END(); API_END();
} }
int LGBM_PredictionEarlyStopInstanceCreate(const char* type,
int round_period,
double margin_threshold,
PredictionEarlyStoppingHandle* out)
{
API_BEGIN();
PredictionEarlyStopConfig config;
config.marginThreshold = margin_threshold;
config.roundPeriod = round_period;
auto earlyStop = createPredictionEarlyStopInstance(type, config);
// create new by copying
*out = new PredictionEarlyStopInstance(earlyStop);
API_END();
}
int LGBM_PredictionEarlyStopInstanceFree(const PredictionEarlyStoppingHandle handle)
{
API_BEGIN();
delete reinterpret_cast<const PredictionEarlyStopInstance*>(handle);
API_END();
}
// ---- start of some help functions // ---- start of some help functions
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
......
...@@ -228,7 +228,8 @@ def test_booster(): ...@@ -228,7 +228,8 @@ def test_booster():
1, 1,
1, 1,
50, 50,
ctypes.c_void_p(),
ctypes.byref(num_preb), ctypes.byref(num_preb),
preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double))) preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
LIB.LGBM_BoosterPredictForFile(booster2, c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/binary_classification/binary.test')), 0, 0, 50, c_str('preb.txt')) LIB.LGBM_BoosterPredictForFile(booster2, c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/binary_classification/binary.test')), 0, 0, 50, ctypes.c_void_p(), c_str('preb.txt'))
LIB.LGBM_BoosterFree(booster2) LIB.LGBM_BoosterFree(booster2)
...@@ -44,6 +44,7 @@ class TestBasic(unittest.TestCase): ...@@ -44,6 +44,7 @@ class TestBasic(unittest.TestCase):
self.assertEqual(len(pred_from_matr), len(pred_from_file)) self.assertEqual(len(pred_from_matr), len(pred_from_file))
for preds in zip(pred_from_matr, pred_from_file): for preds in zip(pred_from_matr, pred_from_file):
self.assertAlmostEqual(*preds, places=15) self.assertAlmostEqual(*preds, places=15)
# check saved model persistence # check saved model persistence
bst = lgb.Booster(params, model_file="model.txt") bst = lgb.Booster(params, model_file="model.txt")
pred_from_model_file = bst.predict(X_test) pred_from_model_file = bst.predict(X_test)
...@@ -51,5 +52,14 @@ class TestBasic(unittest.TestCase): ...@@ -51,5 +52,14 @@ class TestBasic(unittest.TestCase):
for preds in zip(pred_from_matr, pred_from_model_file): for preds in zip(pred_from_matr, pred_from_model_file):
# we need to check the consistency of model file here, so test for exact equal # we need to check the consistency of model file here, so test for exact equal
self.assertEqual(*preds) self.assertEqual(*preds)
# check early stopping is working. Make it stop very early, so the scores should be very close to zero
estop = lgb.PredictionEarlyStopInstance("binary", round_period=5, margin_threshold=1.5)
pred_early_stopping = bst.predict(X_test, early_stop_instance=estop)
self.assertEqual(len(pred_from_matr), len(pred_early_stopping))
for preds in zip(pred_early_stopping, pred_from_matr):
# scores likely to be different, but prediction should still be the same
self.assertEqual(preds[0] > 0, preds[1] > 0)
# check pmml # check pmml
subprocess.call(['python', os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../pmml/pmml.py'), 'model.txt']) subprocess.call(['python', os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../pmml/pmml.py'), 'model.txt'])
...@@ -90,6 +90,33 @@ class TestEngine(unittest.TestCase): ...@@ -90,6 +90,33 @@ class TestEngine(unittest.TestCase):
self.assertLess(ret, 0.2) self.assertLess(ret, 0.2)
self.assertAlmostEqual(evals_result['valid_0']['multi_logloss'][-1], ret, places=5) self.assertAlmostEqual(evals_result['valid_0']['multi_logloss'][-1], ret, places=5)
def test_multiclass_prediction_early_stopping(self):
X, y = load_digits(10, True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
params = {
'objective': 'multiclass',
'metric': 'multi_logloss',
'num_class': 10,
'verbose': -1
}
lgb_train = lgb.Dataset(X_train, y_train, params=params)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train, params=params)
evals_result = {}
gbm = lgb.train(params, lgb_train,
num_boost_round=50,
valid_sets=lgb_eval,
verbose_eval=False,
evals_result=evals_result)
estop = lgb.PredictionEarlyStopInstance("multiclass", round_period=5, margin_threshold=1.5)
ret = multi_logloss(y_test, gbm.predict(X_test, early_stop_instance=estop))
self.assertLess(ret, 0.8)
self.assertGreater(ret, 0.5) # loss will be higher than when evaluating the full model
estop = lgb.PredictionEarlyStopInstance("multiclass", round_period=5, margin_threshold=5.5)
ret = multi_logloss(y_test, gbm.predict(X_test, early_stop_instance=estop))
self.assertLess(ret, 0.2)
def test_early_stopping(self): def test_early_stopping(self):
X, y = load_breast_cancer(True) X, y = load_breast_cancer(True)
params = { params = {
......
...@@ -201,6 +201,7 @@ ...@@ -201,6 +201,7 @@
<ClInclude Include="..\include\LightGBM\metric.h" /> <ClInclude Include="..\include\LightGBM\metric.h" />
<ClInclude Include="..\include\LightGBM\network.h" /> <ClInclude Include="..\include\LightGBM\network.h" />
<ClInclude Include="..\include\LightGBM\objective_function.h" /> <ClInclude Include="..\include\LightGBM\objective_function.h" />
<ClInclude Include="..\include\LightGBM\prediction_early_stop.h" />
<ClInclude Include="..\include\LightGBM\tree.h" /> <ClInclude Include="..\include\LightGBM\tree.h" />
<ClInclude Include="..\include\LightGBM\tree_learner.h" /> <ClInclude Include="..\include\LightGBM\tree_learner.h" />
<ClInclude Include="..\include\LightGBM\utils\array_args.h" /> <ClInclude Include="..\include\LightGBM\utils\array_args.h" />
...@@ -244,6 +245,7 @@ ...@@ -244,6 +245,7 @@
<ClCompile Include="..\src\boosting\boosting.cpp" /> <ClCompile Include="..\src\boosting\boosting.cpp" />
<ClCompile Include="..\src\boosting\gbdt.cpp" /> <ClCompile Include="..\src\boosting\gbdt.cpp" />
<ClCompile Include="..\src\boosting\gbdt_prediction.cpp" /> <ClCompile Include="..\src\boosting\gbdt_prediction.cpp" />
<ClCompile Include="..\src\boosting\prediction_early_stop.cpp" />
<ClCompile Include="..\src\c_api.cpp" /> <ClCompile Include="..\src\c_api.cpp" />
<ClCompile Include="..\src\io\bin.cpp" /> <ClCompile Include="..\src\io\bin.cpp" />
<ClCompile Include="..\src\io\config.cpp" /> <ClCompile Include="..\src\io\config.cpp" />
......
...@@ -126,6 +126,9 @@ ...@@ -126,6 +126,9 @@
<ClInclude Include="..\include\LightGBM\objective_function.h"> <ClInclude Include="..\include\LightGBM\objective_function.h">
<Filter>include\LightGBM</Filter> <Filter>include\LightGBM</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\include\LightGBM\prediction_early_stop.h">
<Filter>include\LightGBM</Filter>
</ClInclude>
<ClInclude Include="..\include\LightGBM\tree.h"> <ClInclude Include="..\include\LightGBM\tree.h">
<Filter>include\LightGBM</Filter> <Filter>include\LightGBM</Filter>
</ClInclude> </ClInclude>
...@@ -260,5 +263,9 @@ ...@@ -260,5 +263,9 @@
<ClCompile Include="..\src\boosting\gbdt_prediction.cpp"> <ClCompile Include="..\src\boosting\gbdt_prediction.cpp">
<Filter>src\boosting</Filter> <Filter>src\boosting</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="..\src\boosting\prediction_early_stop.cpp">
<Filter>src\boosting</Filter>
</ClCompile>
</ItemGroup> </ItemGroup>
</Project> </Project>
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment