Commit 6d4c7b03 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Support early stopping of prediction in CLI (#565)

* fix multi-threading.

* fix name style.

* support in CLI version.

* remove warnings.

* Not default parameters.

* fix if...else... .

* fix bug.

* fix warning.

* refine c_api.

* fix R-package.

* fix R's warning.

* fix tests.

* fix pep8 .
parent e04a8bb4
...@@ -54,7 +54,7 @@ if(USE_GPU) ...@@ -54,7 +54,7 @@ if(USE_GPU)
endif(USE_GPU) endif(USE_GPU)
if(UNIX OR MINGW OR CYGWIN) if(UNIX OR MINGW OR CYGWIN)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -O3 -Wextra -Wall -std=c++11 -Wno-ignored-attributes") SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -O3 -Wextra -Wall -Wno-ignored-attributes -Wno-unknown-pragmas")
endif() endif()
if(MSVC) if(MSVC)
......
...@@ -391,7 +391,7 @@ Booster <- R6Class( ...@@ -391,7 +391,7 @@ Booster <- R6Class(
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE) { reshape = FALSE, ...) {
# Check if number of iteration is non existent # Check if number of iteration is non existent
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
...@@ -399,7 +399,7 @@ Booster <- R6Class( ...@@ -399,7 +399,7 @@ Booster <- R6Class(
} }
# Predict on new data # Predict on new data
predictor <- Predictor$new(private$handle) predictor <- Predictor$new(private$handle, ...)
predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape) predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)
}, },
...@@ -645,7 +645,7 @@ predict.lgb.Booster <- function(object, data, ...@@ -645,7 +645,7 @@ predict.lgb.Booster <- function(object, data,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE) { reshape = FALSE, ...) {
# Check booster existence # Check booster existence
if (!lgb.is.Booster(object)) { if (!lgb.is.Booster(object)) {
...@@ -658,7 +658,7 @@ predict.lgb.Booster <- function(object, data, ...@@ -658,7 +658,7 @@ predict.lgb.Booster <- function(object, data,
rawscore, rawscore,
predleaf, predleaf,
header, header,
reshape) reshape, ...)
} }
#' Load LightGBM model #' Load LightGBM model
......
...@@ -18,8 +18,9 @@ Predictor <- R6Class( ...@@ -18,8 +18,9 @@ Predictor <- R6Class(
}, },
# Initialize will create a starter model # Initialize will create a starter model
initialize = function(modelfile) { initialize = function(modelfile, ...) {
params <- list(...)
private$params <- lgb.params2str(params)
# Create new lgb handle # Create new lgb handle
handle <- lgb.new.handle() handle <- lgb.new.handle()
...@@ -86,6 +87,7 @@ Predictor <- R6Class( ...@@ -86,6 +87,7 @@ Predictor <- R6Class(
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(num_iteration), as.integer(num_iteration),
private$params,
lgb.c_str(tmp_filename)) lgb.c_str(tmp_filename))
# Get predictions from file # Get predictions from file
...@@ -121,7 +123,8 @@ Predictor <- R6Class( ...@@ -121,7 +123,8 @@ Predictor <- R6Class(
as.integer(ncol(data)), as.integer(ncol(data)),
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(num_iteration)) as.integer(num_iteration),
private$params)
} else if (is(data, "dgCMatrix")) { } else if (is(data, "dgCMatrix")) {
...@@ -137,7 +140,8 @@ Predictor <- R6Class( ...@@ -137,7 +140,8 @@ Predictor <- R6Class(
nrow(data), nrow(data),
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(num_iteration)) as.integer(num_iteration),
private$params)
} else { } else {
...@@ -178,5 +182,6 @@ Predictor <- R6Class( ...@@ -178,5 +182,6 @@ Predictor <- R6Class(
), ),
private = list(handle = NULL, private = list(handle = NULL,
need_free_handle = FALSE) need_free_handle = FALSE,
params = "")
) )
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "../../src/boosting/boosting.cpp" #include "../../src/boosting/boosting.cpp"
#include "../../src/boosting/gbdt.cpp" #include "../../src/boosting/gbdt.cpp"
#include "../../src/boosting/gbdt_prediction.cpp" #include "../../src/boosting/gbdt_prediction.cpp"
#include "../../src/boosting/prediction_early_stop.cpp"
// io // io
#include "../../src/io/bin.cpp" #include "../../src/io/bin.cpp"
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "./src/boosting/boosting.cpp" #include "./src/boosting/boosting.cpp"
#include "./src/boosting/gbdt.cpp" #include "./src/boosting/gbdt.cpp"
#include "./src/boosting/gbdt_prediction.cpp" #include "./src/boosting/gbdt_prediction.cpp"
#include "./src/boosting/prediction_early_stop.cpp"
// io // io
#include "./src/io/bin.cpp" #include "./src/io/bin.cpp"
......
...@@ -498,12 +498,13 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle, ...@@ -498,12 +498,13 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename, LGBM_SE result_filename,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx); int pred_type = GetPredictType(is_rawscore, is_leafidx);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename), CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_CHAR_PTR(result_filename))); R_CHAR_PTR(result_filename)));
R_API_END(); R_API_END();
} }
...@@ -534,6 +535,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -534,6 +535,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
LGBM_SE call_state) { LGBM_SE call_state) {
...@@ -552,7 +554,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -552,7 +554,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle), CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
p_indptr, C_API_DTYPE_INT32, p_indices, p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, R_AS_INT(num_iteration), &out_len, ptr_ret)); nrow, pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
R_API_END(); R_API_END();
} }
...@@ -563,6 +565,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -563,6 +565,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
LGBM_SE call_state) { LGBM_SE call_state) {
...@@ -577,7 +580,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -577,7 +580,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle), CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, R_AS_INT(num_iteration), &out_len, ptr_ret)); pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
R_API_END(); R_API_END();
} }
......
...@@ -389,6 +389,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle, ...@@ -389,6 +389,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename, LGBM_SE result_filename,
LGBM_SE call_state); LGBM_SE call_state);
...@@ -438,6 +439,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -438,6 +439,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
LGBM_SE call_state); LGBM_SE call_state);
...@@ -463,6 +465,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -463,6 +465,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
LGBM_SE call_state); LGBM_SE call_state);
......
...@@ -192,6 +192,12 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s ...@@ -192,6 +192,12 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
* `num_iteration_predict`, default=`-1`, type=int * `num_iteration_predict`, default=`-1`, type=int
* only used in prediction task, used to how many trained iterations will be used in prediction. * only used in prediction task, used to how many trained iterations will be used in prediction.
* `<= 0` means no limit * `<= 0` means no limit
* `pred_early_stop`, default=`false`, type=bool
* Set to `true` will use early-stopping to speed up the prediction. May affect the accuracy.
* `pred_early_stop_freq`, default=`10`, type=int
* The frequency of checking early-stopping prediction.
* `pred_early_stop_margin`, default=`10.0`, type=double
* The Threshold of margin in early-stopping prediction.
* `use_missing`, default=`true`, type=bool * `use_missing`, default=`true`, type=bool
* Set to `false` will disbale the special handle of missing value. * Set to `false` will disbale the special handle of missing value.
......
...@@ -117,19 +117,19 @@ public: ...@@ -117,19 +117,19 @@ public:
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param output Prediction result for this record * \param output Prediction result for this record
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated. * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/ */
virtual void PredictRaw(const double* features, double* output, virtual void PredictRaw(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0; const PredictionEarlyStopInstance* early_stop) const = 0;
/*! /*!
* \brief Prediction for one record, sigmoid transformation will be used if needed * \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param output Prediction result for this record * \param output Prediction result for this record
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated. * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/ */
virtual void Predict(const double* features, double* output, virtual void Predict(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0; const PredictionEarlyStopInstance* early_stop) const = 0;
/*! /*!
* \brief Prediction for one record with leaf index * \brief Prediction for one record with leaf index
...@@ -220,6 +220,9 @@ public: ...@@ -220,6 +220,9 @@ public:
*/ */
virtual int NumberOfClasses() const = 0; virtual int NumberOfClasses() const = 0;
/*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */
virtual bool NeedAccuratePrediction() const = 0;
/*! /*!
* \brief Initial work for the prediction * \brief Initial work for the prediction
* \param num_iteration number of used iteration * \param num_iteration number of used iteration
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
typedef void* DatasetHandle; typedef void* DatasetHandle;
typedef void* BoosterHandle; typedef void* BoosterHandle;
typedef void* PredictionEarlyStoppingHandle;
#define C_API_DTYPE_FLOAT32 (0) #define C_API_DTYPE_FLOAT32 (0)
#define C_API_DTYPE_FLOAT64 (1) #define C_API_DTYPE_FLOAT64 (1)
...@@ -522,7 +521,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -522,7 +521,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle, const char* parameter,
const char* result_filename); const char* result_filename);
/*! /*!
...@@ -578,7 +577,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -578,7 +577,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -617,7 +616,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -617,7 +616,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -650,7 +649,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -650,7 +649,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -721,25 +720,6 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle, ...@@ -721,25 +720,6 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int leaf_idx, int leaf_idx,
double val); double val);
/*!
* \brief create an new prediction early stopping instance that can be used to speed up prediction
* \param type early stopping type: "none", "multiclass" or "binary"
* \param round_period how often the classifier score is checked for the early stopping condition
* \param margin_threshold when the margin exceeds this value, early stopping kicks in and no more trees are evaluated
* \param out handle of created instance
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceCreate(const char* type,
int round_period,
double margin_threshold,
PredictionEarlyStoppingHandle* out);
/*!
\brief free prediction early stop instance
\return 0 when succeed
*/
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceFree(const PredictionEarlyStoppingHandle handle);
#if defined(_MSC_VER) #if defined(_MSC_VER)
// exception handle and error msg // exception handle and error msg
static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; } static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; }
...@@ -747,6 +727,7 @@ static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Eve ...@@ -747,6 +727,7 @@ static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Eve
static char* LastErrorMsg() { static thread_local char err_msg[512] = "Everything is fine"; return err_msg; } static char* LastErrorMsg() { static thread_local char err_msg[512] = "Everything is fine"; return err_msg; }
#endif #endif
#pragma warning(disable : 4996)
inline void LGBM_SetLastError(const char* msg) { inline void LGBM_SetLastError(const char* msg) {
std::strcpy(LastErrorMsg(), msg); std::strcpy(LastErrorMsg(), msg);
} }
......
...@@ -135,6 +135,14 @@ public: ...@@ -135,6 +135,14 @@ public:
* Note: when using Index, it doesn't count the label index */ * Note: when using Index, it doesn't count the label index */
std::string categorical_column = ""; std::string categorical_column = "";
std::string device_type = "cpu"; std::string device_type = "cpu";
/*! \brief Set to true if want to use early stop for the prediction */
bool pred_early_stop = false;
/*! \brief Frequency of checking the pred_early_stop */
int pred_early_stop_freq = 10;
/*! \brief Threshold of margin of pred_early_stop */
double pred_early_stop_margin = 10.0f;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override; LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
private: private:
void GetDeviceType(const std::unordered_map<std::string, void GetDeviceType(const std::unordered_map<std::string,
......
...@@ -43,6 +43,9 @@ public: ...@@ -43,6 +43,9 @@ public:
virtual int NumPredictOneRow() const { return 1; } virtual int NumPredictOneRow() const { return 1; }
/*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */
virtual bool NeedAccuratePrediction() const { return true; }
virtual void ConvertOutput(const double* input, double* output) const { virtual void ConvertOutput(const double* input, double* output) const {
output[0] = input[0]; output[0] = input[0];
} }
......
...@@ -6,27 +6,27 @@ ...@@ -6,27 +6,27 @@
#include <LightGBM/export.h> #include <LightGBM/export.h>
namespace LightGBM namespace LightGBM {
{
struct PredictionEarlyStopInstance #pragma warning(disable : 4099)
{ struct PredictionEarlyStopInstance {
/// Callback function type for early stopping. /// Callback function type for early stopping.
/// Takes current prediction and number of elements in prediction /// Takes current prediction and number of elements in prediction
/// @returns true if prediction should stop according to criterion /// @returns true if prediction should stop according to criterion
using FunctionType = std::function<bool(const double*, int)>; using FunctionType = std::function<bool(const double*, int)>;
FunctionType callbackFunction; // callback function itself FunctionType callback_function; // callback function itself
int roundPeriod; // call callbackFunction every `runPeriod` iterations int round_period; // call callback_function every `runPeriod` iterations
}; };
struct PredictionEarlyStopConfig #pragma warning(disable : 4099)
{ struct PredictionEarlyStopConfig {
int roundPeriod; int round_period;
double marginThreshold; double margin_threshold;
}; };
/// Create an early stopping algorithm of type `type`, with given roundPeriod and margin threshold /// Create an early stopping algorithm of type `type`, with given round_period and margin threshold
LIGHTGBM_EXPORT PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type, LIGHTGBM_EXPORT PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
const PredictionEarlyStopConfig& config); const PredictionEarlyStopConfig& config);
} // namespace LightGBM } // namespace LightGBM
......
...@@ -508,7 +508,7 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) { ...@@ -508,7 +508,7 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
// Buffer for merge. // Buffer for merge.
std::vector<_VTRanIt> temp_buf(len); std::vector<_VTRanIt> temp_buf(len);
_RanIt buf = temp_buf.begin(); _RanIt buf = temp_buf.begin();
int s = inner_size; size_t s = inner_size;
// Recursive merge // Recursive merge
while (s < len) { while (s < len) {
int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2)); int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2));
......
...@@ -6,7 +6,7 @@ Contributors: https://github.com/Microsoft/LightGBM/graphs/contributors ...@@ -6,7 +6,7 @@ Contributors: https://github.com/Microsoft/LightGBM/graphs/contributors
from __future__ import absolute_import from __future__ import absolute_import
from .basic import Booster, Dataset, PredictionEarlyStopInstance from .basic import Booster, Dataset
from .callback import (early_stopping, print_evaluation, record_evaluation, from .callback import (early_stopping, print_evaluation, record_evaluation,
reset_parameter) reset_parameter)
from .engine import cv, train from .engine import cv, train
...@@ -23,7 +23,7 @@ except ImportError: ...@@ -23,7 +23,7 @@ except ImportError:
__version__ = 0.2 __version__ = 0.2
__all__ = ['Dataset', 'Booster', 'PredictionEarlyStopInstance', __all__ = ['Dataset', 'Booster',
'train', 'cv', 'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
......
...@@ -296,7 +296,7 @@ class _InnerPredictor(object): ...@@ -296,7 +296,7 @@ class _InnerPredictor(object):
Only used for prediction, usually used for continued-train Only used for prediction, usually used for continued-train
Note: Can convert from Booster, but cannot convert to Booster Note: Can convert from Booster, but cannot convert to Booster
""" """
def __init__(self, model_file=None, booster_handle=None, early_stop_instance=None): def __init__(self, model_file=None, booster_handle=None, pred_parameter=None):
"""Initialize the _InnerPredictor. Not expose to user """Initialize the _InnerPredictor. Not expose to user
Parameters Parameters
...@@ -305,8 +305,8 @@ class _InnerPredictor(object): ...@@ -305,8 +305,8 @@ class _InnerPredictor(object):
Path to the model file. Path to the model file.
booster_handle : Handle of Booster booster_handle : Handle of Booster
use handle to init use handle to init
early_stop_instance: object of type PredictionEarlyStopInstance pred_parameter: dict
If None, no early stopping is applied Other parameters for the prediciton
""" """
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
self.__is_manage_handle = True self.__is_manage_handle = True
...@@ -341,10 +341,8 @@ class _InnerPredictor(object): ...@@ -341,10 +341,8 @@ class _InnerPredictor(object):
else: else:
raise TypeError('Need Model file or Booster handle to create a predictor') raise TypeError('Need Model file or Booster handle to create a predictor')
if early_stop_instance is None: pred_parameter = {} if pred_parameter is None else pred_parameter
self.early_stop_instance = PredictionEarlyStopInstance("none") self.pred_parameter = param_dict_to_str(pred_parameter)
else:
self.early_stop_instance = early_stop_instance
def __del__(self): def __del__(self):
if self.__is_manage_handle: if self.__is_manage_handle:
...@@ -401,7 +399,7 @@ class _InnerPredictor(object): ...@@ -401,7 +399,7 @@ class _InnerPredictor(object):
ctypes.c_int(int_data_has_header), ctypes.c_int(int_data_has_header),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
self.early_stop_instance.handle, c_str(self.pred_parameter),
c_str(f.name))) c_str(f.name)))
lines = f.readlines() lines = f.readlines()
nrow = len(lines) nrow = len(lines)
...@@ -475,7 +473,7 @@ class _InnerPredictor(object): ...@@ -475,7 +473,7 @@ class _InnerPredictor(object):
ctypes.c_int(C_API_IS_ROW_MAJOR), ctypes.c_int(C_API_IS_ROW_MAJOR),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
self.early_stop_instance.handle, c_str(self.pred_parameter),
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
...@@ -506,7 +504,7 @@ class _InnerPredictor(object): ...@@ -506,7 +504,7 @@ class _InnerPredictor(object):
ctypes.c_int64(csr.shape[1]), ctypes.c_int64(csr.shape[1]),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
self.early_stop_instance.handle, c_str(self.pred_parameter),
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
...@@ -537,7 +535,7 @@ class _InnerPredictor(object): ...@@ -537,7 +535,7 @@ class _InnerPredictor(object):
ctypes.c_int64(csc.shape[0]), ctypes.c_int64(csc.shape[0]),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
self.early_stop_instance.handle, c_str(self.pred_parameter),
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
...@@ -1581,7 +1579,7 @@ class Booster(object): ...@@ -1581,7 +1579,7 @@ class Booster(object):
return json.loads(string_buffer.value.decode()) return json.loads(string_buffer.value.decode())
def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True, def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True,
early_stop_instance=None): pred_parameter=None):
""" """
Predict logic Predict logic
...@@ -1600,21 +1598,21 @@ class Booster(object): ...@@ -1600,21 +1598,21 @@ class Booster(object):
Used for txt data Used for txt data
is_reshape : bool is_reshape : bool
Reshape to (nrow, ncol) if true Reshape to (nrow, ncol) if true
early_stop_instance: object of type PredictionEarlyStopInstance. pred_parameter: dict
If None, no early stopping is applied Other parameters for the prediction
Returns Returns
------- -------
Prediction result Prediction result
""" """
predictor = self._to_predictor(early_stop_instance) predictor = self._to_predictor(pred_parameter)
if num_iteration <= 0: if num_iteration <= 0:
num_iteration = self.best_iteration num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape) return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape)
def _to_predictor(self, early_stop_instance=None): def _to_predictor(self, pred_parameter=None):
"""Convert to predictor""" """Convert to predictor"""
predictor = _InnerPredictor(booster_handle=self.handle, early_stop_instance=early_stop_instance) predictor = _InnerPredictor(booster_handle=self.handle, pred_parameter=pred_parameter)
predictor.pandas_categorical = self.pandas_categorical predictor.pandas_categorical = self.pandas_categorical
return predictor return predictor
...@@ -1800,35 +1798,3 @@ class Booster(object): ...@@ -1800,35 +1798,3 @@ class Booster(object):
self.__attr[key] = value self.__attr[key] = value
else: else:
self.__attr.pop(key, None) self.__attr.pop(key, None)
class PredictionEarlyStopInstance(object):
""""PredictionEarlyStopInstance in LightGBM."""
def __init__(self, early_stop_type="none", round_period=20, margin_threshold=1.5):
"""
Create an early stopping object
Parameters
----------
early_stop_type: string
None, "none", "binary" or "multiclass". Regression is not supported.
round_period : int
The score will be checked every round_period to check if the early stopping criteria is met
margin_threshold : double
Early stopping will kick in when the margin is greater than margin_threshold
"""
self.handle = ctypes.c_void_p(0)
self.__attr = {}
if early_stop_type is None:
early_stop_type = "none"
_safe_call(_LIB.LGBM_PredictionEarlyStopInstanceCreate(
c_str(early_stop_type),
ctypes.c_int(round_period),
ctypes.c_double(margin_threshold),
ctypes.byref(self.handle)))
def __del__(self):
if self.handle is not None:
_safe_call(_LIB.LGBM_PredictionEarlyStopInstanceFree(self.handle))
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <LightGBM/dataset_loader.h> #include <LightGBM/dataset_loader.h>
#include <LightGBM/boosting.h> #include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h> #include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include "predictor.hpp" #include "predictor.hpp"
...@@ -107,9 +108,10 @@ void Application::LoadData() { ...@@ -107,9 +108,10 @@ void Application::LoadData() {
std::unique_ptr<Predictor> predictor; std::unique_ptr<Predictor> predictor;
// prediction is needed if using input initial model(continued train) // prediction is needed if using input initial model(continued train)
PredictFunction predict_fun = nullptr; PredictFunction predict_fun = nullptr;
PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
// need to continue training // need to continue training
if (boosting_->NumberOfTotalModel() > 0) { if (boosting_->NumberOfTotalModel() > 0) {
predictor.reset(new Predictor(boosting_.get(), -1, true, false)); predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, -1, -1));
predict_fun = predictor->GetPredictFunction(); predict_fun = predictor->GetPredictFunction();
} }
...@@ -250,7 +252,8 @@ void Application::Train() { ...@@ -250,7 +252,8 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
// create predictor // create predictor
Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score, Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
config_.io_config.is_predict_leaf_index); config_.io_config.is_predict_leaf_index, config_.io_config.pred_early_stop,
config_.io_config.pred_early_stop_freq, config_.io_config.pred_early_stop_margin);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finished prediction"); Log::Info("Finished prediction");
......
...@@ -32,7 +32,20 @@ public: ...@@ -32,7 +32,20 @@ public:
*/ */
Predictor(Boosting* boosting, int num_iteration, Predictor(Boosting* boosting, int num_iteration,
bool is_raw_score, bool is_predict_leaf_index, bool is_raw_score, bool is_predict_leaf_index,
const PredictionEarlyStopInstance* earlyStop = nullptr) { bool early_stop, int early_stop_freq, double early_stop_margin) {
early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
if (early_stop && !boosting->NeedAccuratePrediction()) {
PredictionEarlyStopConfig pred_early_stop_config;
pred_early_stop_config.margin_threshold = early_stop_margin;
pred_early_stop_config.round_period = early_stop_freq;
if (boosting->NumberOfClasses() == 1) {
early_stop_ = CreatePredictionEarlyStopInstance("binary", pred_early_stop_config);
} else {
early_stop_ = CreatePredictionEarlyStopInstance("multiclass", pred_early_stop_config);
}
}
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
{ {
...@@ -55,17 +68,17 @@ public: ...@@ -55,17 +68,17 @@ public:
} else { } else {
if (is_raw_score) { if (is_raw_score) {
predict_fun_ = [this, earlyStop](const std::vector<std::pair<int, double>>& features, double* output) { predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features); CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->PredictRaw(predict_buf_[tid].data(), output, earlyStop); boosting_->PredictRaw(predict_buf_[tid].data(), output, &early_stop_);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features); ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
}; };
} else { } else {
predict_fun_ = [this, earlyStop](const std::vector<std::pair<int, double>>& features, double* output) { predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features); CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->Predict(predict_buf_[tid].data(), output, earlyStop); boosting_->Predict(predict_buf_[tid].data(), output, &early_stop_);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features); ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
}; };
} }
...@@ -117,7 +130,11 @@ public: ...@@ -117,7 +130,11 @@ public:
[this, &parser_fun, &result_file] [this, &parser_fun, &result_file]
(data_size_t, const std::vector<std::string>& lines) { (data_size_t, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, double>> oneline_features;
std::vector<std::string> result_to_write(lines.size());
OMP_INIT_EX();
#pragma omp parallel for schedule(static) firstprivate(oneline_features)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) { for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
OMP_LOOP_EX_BEGIN();
oneline_features.clear(); oneline_features.clear();
// parser // parser
parser_fun(lines[i].c_str(), &oneline_features); parser_fun(lines[i].c_str(), &oneline_features);
...@@ -125,7 +142,12 @@ public: ...@@ -125,7 +142,12 @@ public:
std::vector<double> result(num_pred_one_row_); std::vector<double> result(num_pred_one_row_);
predict_fun_(oneline_features, result.data()); predict_fun_(oneline_features, result.data());
auto str_result = Common::Join<double>(result, "\t"); auto str_result = Common::Join<double>(result, "\t");
fprintf(result_file, "%s\n", str_result.c_str()); result_to_write[i] = str_result;
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
fprintf(result_file, "%s\n", result_to_write[i].c_str());
} }
}; };
TextReader<data_size_t> predict_data_reader(data_filename, has_header); TextReader<data_size_t> predict_data_reader(data_filename, has_header);
...@@ -137,7 +159,6 @@ private: ...@@ -137,7 +159,6 @@ private:
void CopyToPredictBuffer(double* pred_buf, const std::vector<std::pair<int, double>>& features) { void CopyToPredictBuffer(double* pred_buf, const std::vector<std::pair<int, double>>& features) {
int loop_size = static_cast<int>(features.size()); int loop_size = static_cast<int>(features.size());
#pragma omp parallel for schedule(static,128) if (loop_size >= 256)
for (int i = 0; i < loop_size; ++i) { for (int i = 0; i < loop_size; ++i) {
if (features[i].first < num_feature_) { if (features[i].first < num_feature_) {
pred_buf[features[i].first] = features[i].second; pred_buf[features[i].first] = features[i].second;
...@@ -150,7 +171,6 @@ private: ...@@ -150,7 +171,6 @@ private:
std::memset(pred_buf, 0, sizeof(double)*(buf_size)); std::memset(pred_buf, 0, sizeof(double)*(buf_size));
} else { } else {
int loop_size = static_cast<int>(features.size()); int loop_size = static_cast<int>(features.size());
#pragma omp parallel for schedule(static,128) if (loop_size >= 256)
for (int i = 0; i < loop_size; ++i) { for (int i = 0; i < loop_size; ++i) {
pred_buf[features[i].first] = 0.0f; pred_buf[features[i].first] = 0.0f;
} }
...@@ -161,6 +181,7 @@ private: ...@@ -161,6 +181,7 @@ private:
const Boosting* boosting_; const Boosting* boosting_;
/*! \brief function for prediction */ /*! \brief function for prediction */
PredictFunction predict_fun_; PredictFunction predict_fun_;
PredictionEarlyStopInstance early_stop_;
int num_feature_; int num_feature_;
int num_pred_one_row_; int num_pred_one_row_;
int num_threads_; int num_threads_;
......
...@@ -262,8 +262,6 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t ...@@ -262,8 +262,6 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
return cur_left_cnt; return cur_left_cnt;
} }
void GBDT::Bagging(int iter) { void GBDT::Bagging(int iter) {
// if need bagging // if need bagging
if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) { if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
...@@ -738,32 +736,27 @@ std::string GBDT::ModelToIfElse(int num_iteration) const { ...@@ -738,32 +736,27 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream pred_str_buf; std::stringstream pred_str_buf;
pred_str_buf << "\t" << "const auto noEarlyStop = createPredictionEarlyStopInstance(\"none\", PredictionEarlyStopConfig());" << std::endl; pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << std::endl;
pred_str_buf << "\t" << "if (earlyStop == nullptr) {" << std::endl;
pred_str_buf << "\t\t" << "earlyStop = &noEarlyStop;" << std::endl;
pred_str_buf << "\t" << "}" << std::endl;
pred_str_buf << "\t" << "int earlyStopRoundCounter = 0;" << std::endl;
pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl; pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl; pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl; pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl; pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t\t" << "++earlyStopRoundCounter;" << std::endl; pred_str_buf << "\t\t" << "++early_stop_round_counter;" << std::endl;
pred_str_buf << "\t\t" << "if (earlyStop->roundPeriod == earlyStopRoundCounter) {" << std::endl; pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << std::endl;
pred_str_buf << "\t\t\t" << "if (earlyStop->callbackFunction(output, num_tree_per_iteration_))" << std::endl; pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << std::endl;
pred_str_buf << "\t\t\t\t" << "return;" << std::endl; pred_str_buf << "\t\t\t\t" << "return;" << std::endl;
pred_str_buf << "\t\t\t" << "earlyStopRoundCounter = 0;" << std::endl; pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl; pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "}" << std::endl; pred_str_buf << "\t" << "}" << std::endl;
str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* earlyStop) const {" << std::endl; str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
str_buf << pred_str_buf.str(); str_buf << pred_str_buf.str();
str_buf << "}" << std::endl; str_buf << "}" << std::endl;
str_buf << std::endl; str_buf << std::endl;
// Predict // Predict
str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* earlyStop) const {" << std::endl; str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
str_buf << "\t" << "PredictRaw(features, output, earlyStop);" << std::endl; str_buf << "\t" << "PredictRaw(features, output, early_stop);" << std::endl;
str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl; str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl; str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl; str_buf << "\t" << "}" << std::endl;
...@@ -786,7 +779,6 @@ std::string GBDT::ModelToIfElse(int num_iteration) const { ...@@ -786,7 +779,6 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl; str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl; str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
str_buf << "\t" << "#pragma omp parallel for schedule(static)" << std::endl;
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl; str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl; str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
str_buf << "\t" << "}" << std::endl; str_buf << "\t" << "}" << std::endl;
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#define LIGHTGBM_BOOSTING_GBDT_H_ #define LIGHTGBM_BOOSTING_GBDT_H_
#include <LightGBM/boosting.h> #include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include "score_updater.hpp" #include "score_updater.hpp"
#include <cstdio> #include <cstdio>
...@@ -93,6 +95,14 @@ public: ...@@ -93,6 +95,14 @@ public:
bool EvalAndCheckEarlyStopping() override; bool EvalAndCheckEarlyStopping() override;
bool NeedAccuratePrediction() const override {
if (objective_function_ == nullptr) {
return true;
} else {
return objective_function_->NeedAccuratePrediction();
}
}
/*! /*!
* \brief Get evaluation result at data_idx data * \brief Get evaluation result at data_idx data
* \param data_idx 0: training data, 1: 1st validation data * \param data_idx 0: training data, 1: 1st validation data
...@@ -137,7 +147,7 @@ public: ...@@ -137,7 +147,7 @@ public:
} }
void PredictRaw(const double* features, double* output, void PredictRaw(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const override; const PredictionEarlyStopInstance* earlyStop) const override;
void Predict(const double* features, double* output, void Predict(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop) const override; const PredictionEarlyStopInstance* earlyStop) const override;
...@@ -365,7 +375,6 @@ protected: ...@@ -365,7 +375,6 @@ protected:
std::vector<double> class_default_output_; std::vector<double> class_default_output_;
bool is_constant_hessian_; bool is_constant_hessian_;
std::unique_ptr<ObjectiveFunction> loaded_objective_; std::unique_ptr<ObjectiveFunction> loaded_objective_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment