Unverified Commit 82e2ff7a authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

[Python] / [R] add start_iteration to python predict interface (fix #3058) (#3272)



* [python] add start_iteration to python predict interface (#3058)

* Apply suggestions from code review

* Update lightgbm_R.h

* Apply suggestions from code review

* Apply suggestions from code review

* fix R interface

* update R documentation
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>
parent 083b02af
...@@ -483,6 +483,7 @@ Booster <- R6::R6Class( ...@@ -483,6 +483,7 @@ Booster <- R6::R6Class(
# Predict on new data # Predict on new data
predict = function(data, predict = function(data,
start_iteration = NULL,
num_iteration = NULL, num_iteration = NULL,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
...@@ -494,10 +495,14 @@ Booster <- R6::R6Class( ...@@ -494,10 +495,14 @@ Booster <- R6::R6Class(
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
num_iteration <- self$best_iter num_iteration <- self$best_iter
} }
# Check if start iteration is non existent
if (is.null(start_iteration)) {
start_iteration <- 0L
}
# Predict on new data # Predict on new data
predictor <- Predictor$new(private$handle, ...) predictor <- Predictor$new(private$handle, ...)
predictor$predict(data, num_iteration, rawscore, predleaf, predcontrib, header, reshape) predictor$predict(data, start_iteration, num_iteration, rawscore, predleaf, predcontrib, header, reshape)
}, },
...@@ -698,7 +703,14 @@ Booster <- R6::R6Class( ...@@ -698,7 +703,14 @@ Booster <- R6::R6Class(
#' @description Predicted values based on class \code{lgb.Booster} #' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster} #' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename #' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration #' @param start_iteration int or None, optional (default=None)
#' Start index of the iteration to predict.
#' If None or <= 0, starts from the first iteration.
#' @param num_iteration int or None, optional (default=None)
#' Limit number of iterations in the prediction.
#' If None, if the best iteration exists and start_iteration is None or <= 0, the
#' best iteration is used; otherwise, all iterations from start_iteration are used.
#' If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the for of original untransformed #' @param rawscore whether the prediction should be returned in the for of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} #' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#' for logistic regression would result in predictions for log-odds instead of probabilities. #' for logistic regression would result in predictions for log-odds instead of probabilities.
...@@ -740,6 +752,7 @@ Booster <- R6::R6Class( ...@@ -740,6 +752,7 @@ Booster <- R6::R6Class(
#' @export #' @export
predict.lgb.Booster <- function(object, predict.lgb.Booster <- function(object,
data, data,
start_iteration = NULL,
num_iteration = NULL, num_iteration = NULL,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
...@@ -756,6 +769,7 @@ predict.lgb.Booster <- function(object, ...@@ -756,6 +769,7 @@ predict.lgb.Booster <- function(object,
# Return booster predictions # Return booster predictions
object$predict( object$predict(
data data
, start_iteration
, num_iteration , num_iteration
, rawscore , rawscore
, predleaf , predleaf
......
...@@ -76,6 +76,7 @@ Predictor <- R6::R6Class( ...@@ -76,6 +76,7 @@ Predictor <- R6::R6Class(
# Predict from data # Predict from data
predict = function(data, predict = function(data,
start_iteration = NULL,
num_iteration = NULL, num_iteration = NULL,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
...@@ -87,6 +88,10 @@ Predictor <- R6::R6Class( ...@@ -87,6 +88,10 @@ Predictor <- R6::R6Class(
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
num_iteration <- -1L num_iteration <- -1L
} }
# Check if start iterations is existing - if not, then set it to 0 (start from the first iteration)
if (is.null(start_iteration)) {
start_iteration <- 0L
}
# Set temporary variable # Set temporary variable
num_row <- 0L num_row <- 0L
...@@ -108,6 +113,7 @@ Predictor <- R6::R6Class( ...@@ -108,6 +113,7 @@ Predictor <- R6::R6Class(
, as.integer(rawscore) , as.integer(rawscore)
, as.integer(predleaf) , as.integer(predleaf)
, as.integer(predcontrib) , as.integer(predcontrib)
, as.integer(start_iteration)
, as.integer(num_iteration) , as.integer(num_iteration)
, private$params , private$params
, lgb.c_str(tmp_filename) , lgb.c_str(tmp_filename)
...@@ -134,6 +140,7 @@ Predictor <- R6::R6Class( ...@@ -134,6 +140,7 @@ Predictor <- R6::R6Class(
, as.integer(rawscore) , as.integer(rawscore)
, as.integer(predleaf) , as.integer(predleaf)
, as.integer(predcontrib) , as.integer(predcontrib)
, as.integer(start_iteration)
, as.integer(num_iteration) , as.integer(num_iteration)
) )
...@@ -156,6 +163,7 @@ Predictor <- R6::R6Class( ...@@ -156,6 +163,7 @@ Predictor <- R6::R6Class(
, as.integer(rawscore) , as.integer(rawscore)
, as.integer(predleaf) , as.integer(predleaf)
, as.integer(predcontrib) , as.integer(predcontrib)
, as.integer(start_iteration)
, as.integer(num_iteration) , as.integer(num_iteration)
, private$params , private$params
) )
...@@ -178,6 +186,7 @@ Predictor <- R6::R6Class( ...@@ -178,6 +186,7 @@ Predictor <- R6::R6Class(
, as.integer(rawscore) , as.integer(rawscore)
, as.integer(predleaf) , as.integer(predleaf)
, as.integer(predcontrib) , as.integer(predcontrib)
, as.integer(start_iteration)
, as.integer(num_iteration) , as.integer(num_iteration)
, private$params , private$params
) )
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
\method{predict}{lgb.Booster}( \method{predict}{lgb.Booster}(
object, object,
data, data,
start_iteration = NULL,
num_iteration = NULL, num_iteration = NULL,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
...@@ -21,7 +22,15 @@ ...@@ -21,7 +22,15 @@
\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename} \item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename}
\item{num_iteration}{number of iteration want to predict with, NULL or <= 0 means use best iteration} \item{start_iteration}{int or None, optional (default=None)
Start index of the iteration to predict.
If None or <= 0, starts from the first iteration.}
\item{num_iteration}{int or None, optional (default=None)
Limit number of iterations in the prediction.
If None, if the best iteration exists and start_iteration is None or <= 0, the
best iteration is used; otherwise, all iterations from start_iteration are used.
If <= 0, all iterations from start_iteration are used (no limits).}
\item{rawscore}{whether the prediction should be returned in the for of original untransformed \item{rawscore}{whether the prediction should be returned in the for of original untransformed
sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
......
...@@ -541,6 +541,7 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle, ...@@ -541,6 +541,7 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib, LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE result_filename, LGBM_SE result_filename,
...@@ -548,7 +549,7 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle, ...@@ -548,7 +549,7 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename), CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), R_AS_INT(data_has_header), pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_CHAR_PTR(result_filename))); R_CHAR_PTR(result_filename)));
R_API_END(); R_API_END();
} }
...@@ -558,6 +559,7 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle, ...@@ -558,6 +559,7 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib, LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE out_len, LGBM_SE out_len,
LGBM_SE call_state) { LGBM_SE call_state) {
...@@ -565,7 +567,7 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle, ...@@ -565,7 +567,7 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0; int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row), CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
pred_type, R_AS_INT(num_iteration), &len)); pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), &len));
R_INT_PTR(out_len)[0] = static_cast<int>(len); R_INT_PTR(out_len)[0] = static_cast<int>(len);
R_API_END(); R_API_END();
} }
...@@ -580,6 +582,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -580,6 +582,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib, LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
...@@ -599,7 +602,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -599,7 +602,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle), CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
p_indptr, C_API_DTYPE_INT32, p_indices, p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret)); nrow, pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
R_API_END(); R_API_END();
} }
...@@ -610,6 +613,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -610,6 +613,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib, LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
...@@ -625,7 +629,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -625,7 +629,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle), CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret)); pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
R_API_END(); R_API_END();
} }
...@@ -706,10 +710,10 @@ static const R_CallMethodDef CallEntries[] = { ...@@ -706,10 +710,10 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterGetEval_R" , (DL_FUNC) &LGBM_BoosterGetEval_R , 4}, {"LGBM_BoosterGetEval_R" , (DL_FUNC) &LGBM_BoosterGetEval_R , 4},
{"LGBM_BoosterGetNumPredict_R" , (DL_FUNC) &LGBM_BoosterGetNumPredict_R , 4}, {"LGBM_BoosterGetNumPredict_R" , (DL_FUNC) &LGBM_BoosterGetNumPredict_R , 4},
{"LGBM_BoosterGetPredict_R" , (DL_FUNC) &LGBM_BoosterGetPredict_R , 4}, {"LGBM_BoosterGetPredict_R" , (DL_FUNC) &LGBM_BoosterGetPredict_R , 4},
{"LGBM_BoosterPredictForFile_R" , (DL_FUNC) &LGBM_BoosterPredictForFile_R , 10}, {"LGBM_BoosterPredictForFile_R" , (DL_FUNC) &LGBM_BoosterPredictForFile_R , 11},
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8}, {"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 9},
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14}, {"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 15},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11}, {"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 12},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5}, {"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 7}, {"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 7},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 7}, {"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 7},
......
...@@ -489,6 +489,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R( ...@@ -489,6 +489,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib, LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE result_filename, LGBM_SE result_filename,
...@@ -511,6 +512,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterCalcNumPredict_R( ...@@ -511,6 +512,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterCalcNumPredict_R(
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib, LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE out_len, LGBM_SE out_len,
LGBM_SE call_state LGBM_SE call_state
...@@ -545,6 +547,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R( ...@@ -545,6 +547,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib, LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
...@@ -574,6 +577,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R( ...@@ -574,6 +577,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib, LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
......
...@@ -17,3 +17,63 @@ test_that("predictions do not fail for integer input", { ...@@ -17,3 +17,63 @@ test_that("predictions do not fail for integer input", {
pred_double <- predict(fit, X_double) pred_double <- predict(fit, X_double)
expect_equal(pred_integer, pred_double) expect_equal(pred_integer, pred_double)
}) })
test_that("start_iteration works correctly", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test
dtrain <- lgb.Dataset(
agaricus.train$data
, label = agaricus.train$label
)
dtest <- lgb.Dataset.create.valid(
dtrain
, agaricus.test$data
, label = agaricus.test$label
)
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, num_leaves = 4L
, learning_rate = 0.6
, nrounds = 100L
, objective = "binary"
, save_name = tempfile(fileext = ".model")
, valids = list("test" = dtest)
, early_stopping_rounds = 2L
)
expect_true(lgb.is.Booster(bst))
pred1 <- predict(bst, data = test$data, rawscore = TRUE)
pred_contrib1 <- predict(bst, test$data, predcontrib = TRUE)
pred2 <- rep(0.0, length(pred1))
pred_contrib2 <- rep(0.0, length(pred2))
step <- 11L
end_iter <- 99L
if (bst$best_iter != -1L) {
end_iter <- bst$best_iter - 1L
}
start_iters <- seq(0L, end_iter, by = step)
for (start_iter in start_iters) {
n_iter <- min(c(end_iter - start_iter + 1L, step))
inc_pred <- predict(bst, test$data
, start_iteration = start_iter
, num_iteration = n_iter
, rawscore = TRUE
)
inc_pred_contrib <- bst$predict(test$data
, start_iteration = start_iter
, num_iteration = n_iter
, predcontrib = TRUE
)
pred2 <- pred2 + inc_pred
pred_contrib2 <- pred_contrib2 + inc_pred_contrib
}
expect_equal(pred2, pred1)
expect_equal(pred_contrib2, pred_contrib1)
pred_leaf1 <- predict(bst, test$data, predleaf = TRUE)
pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, predleaf = TRUE)
expect_equal(pred_leaf1, pred_leaf2)
})
...@@ -767,6 +767,14 @@ Dataset Parameters ...@@ -767,6 +767,14 @@ Dataset Parameters
Predict Parameters Predict Parameters
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
- ``start_iteration_predict`` :raw-html:`<a id="start_iteration_predict" title="Permalink to this parameter" href="#start_iteration_predict">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
- used only in ``prediction`` task
- used to specify from which iteration to start the prediction
- ``<= 0`` means from the first iteration
- ``num_iteration_predict`` :raw-html:`<a id="num_iteration_predict" title="Permalink to this parameter" href="#num_iteration_predict">&#x1F517;&#xFE0E;</a>`, default = ``-1``, type = int - ``num_iteration_predict`` :raw-html:`<a id="num_iteration_predict" title="Permalink to this parameter" href="#num_iteration_predict">&#x1F517;&#xFE0E;</a>`, default = ``-1``, type = int
- used only in ``prediction`` task - used only in ``prediction`` task
......
...@@ -123,7 +123,7 @@ class LIGHTGBM_EXPORT Boosting { ...@@ -123,7 +123,7 @@ class LIGHTGBM_EXPORT Boosting {
*/ */
virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0; virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;
virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0; virtual int NumPredictOneRow(int start_iteration, int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;
/*! /*!
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
...@@ -284,10 +284,11 @@ class LIGHTGBM_EXPORT Boosting { ...@@ -284,10 +284,11 @@ class LIGHTGBM_EXPORT Boosting {
/*! /*!
* \brief Initial work for the prediction * \brief Initial work for the prediction
* \param start_iteration Start index of the iteration to predict
* \param num_iteration number of used iteration * \param num_iteration number of used iteration
* \param is_pred_contrib * \param is_pred_contrib
*/ */
virtual void InitPredict(int num_iteration, bool is_pred_contrib) = 0; virtual void InitPredict(int start_iteration, int num_iteration, bool is_pred_contrib) = 0;
/*! /*!
* \brief Name of submodel * \brief Name of submodel
......
...@@ -675,6 +675,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle, ...@@ -675,6 +675,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit * \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param result_filename Filename of result file in which predictions will be written * \param result_filename Filename of result file in which predictions will be written
...@@ -684,6 +685,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -684,6 +685,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename, const char* data_filename,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
const char* result_filename); const char* result_filename);
...@@ -697,6 +699,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -697,6 +699,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit * \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param[out] out_len Length of prediction * \param[out] out_len Length of prediction
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
...@@ -704,6 +707,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -704,6 +707,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int num_row, int num_row,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
int64_t* out_len); int64_t* out_len);
...@@ -736,6 +740,7 @@ LIGHTGBM_C_EXPORT int LGBM_FastConfigFree(FastConfigHandle fastConfig); ...@@ -736,6 +740,7 @@ LIGHTGBM_C_EXPORT int LGBM_FastConfigFree(FastConfigHandle fastConfig);
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit * \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result * \param[out] out_len Length of output result
...@@ -752,6 +757,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -752,6 +757,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -775,6 +781,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -775,6 +781,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* \param num_col_or_row Number of columns for CSR or number of rows for CSC * \param num_col_or_row Number of columns for CSR or number of rows for CSC
* \param predict_type What should be predicted, only feature contributions supported currently * \param predict_type What should be predicted, only feature contributions supported currently
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit * \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param matrix_type Type of matrix input and output, can be ``C_API_MATRIX_TYPE_CSR`` or ``C_API_MATRIX_TYPE_CSC`` * \param matrix_type Type of matrix input and output, can be ``C_API_MATRIX_TYPE_CSR`` or ``C_API_MATRIX_TYPE_CSC``
...@@ -794,6 +801,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictSparseOutput(BoosterHandle handle, ...@@ -794,6 +801,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictSparseOutput(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_col_or_row, int64_t num_col_or_row,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int matrix_type, int matrix_type,
...@@ -835,6 +843,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFreePredictSparse(void* indptr, int32_t* indic ...@@ -835,6 +843,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFreePredictSparse(void* indptr, int32_t* indic
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit * \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result * \param[out] out_len Length of output result
...@@ -851,6 +860,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -851,6 +860,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -867,6 +877,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -867,6 +877,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit * \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param num_col Number of columns * \param num_col Number of columns
...@@ -876,6 +887,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -876,6 +887,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
const int predict_type, const int predict_type,
const int start_iteration,
const int num_iteration, const int num_iteration,
const int data_type, const int data_type,
const int64_t num_col, const int64_t num_col,
...@@ -944,6 +956,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fa ...@@ -944,6 +956,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fa
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iteration for prediction, <= 0 means no limit * \param num_iteration Number of iteration for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result * \param[out] out_len Length of output result
...@@ -960,6 +973,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -960,6 +973,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -983,6 +997,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -983,6 +997,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iteration for prediction, <= 0 means no limit * \param num_iteration Number of iteration for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result * \param[out] out_len Length of output result
...@@ -996,6 +1011,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -996,6 +1011,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -1019,6 +1035,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1019,6 +1035,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iteration for prediction, <= 0 means no limit * \param num_iteration Number of iteration for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result * \param[out] out_len Length of output result
...@@ -1031,6 +1048,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -1031,6 +1048,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
int ncol, int ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -1047,6 +1065,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -1047,6 +1065,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit * \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param ncol Number of columns * \param ncol Number of columns
...@@ -1056,6 +1075,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -1056,6 +1075,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
const int predict_type, const int predict_type,
const int start_iteration,
const int num_iteration, const int num_iteration,
const int data_type, const int data_type,
const int32_t ncol, const int32_t ncol,
...@@ -1104,6 +1124,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fa ...@@ -1104,6 +1124,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fa
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index; * - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iteration for prediction, <= 0 means no limit * \param num_iteration Number of iteration for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result * \param[out] out_len Length of output result
...@@ -1116,6 +1137,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle, ...@@ -1116,6 +1137,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
......
...@@ -684,6 +684,12 @@ struct Config { ...@@ -684,6 +684,12 @@ struct Config {
#pragma region Predict Parameters #pragma region Predict Parameters
// [no-save]
// desc = used only in ``prediction`` task
// desc = used to specify from which iteration to start the prediction
// desc = ``<= 0`` means from the first iteration
int start_iteration_predict = 0;
// [no-save] // [no-save]
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = used to specify how many trained iterations will be used in prediction // desc = used to specify how many trained iterations will be used in prediction
......
...@@ -519,7 +519,7 @@ class _InnerPredictor(object): ...@@ -519,7 +519,7 @@ class _InnerPredictor(object):
this.pop('handle', None) this.pop('handle', None)
return this return this
def predict(self, data, num_iteration=-1, def predict(self, data, start_iteration=0, num_iteration=-1,
raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False, raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False,
is_reshape=True): is_reshape=True):
"""Predict logic. """Predict logic.
...@@ -529,6 +529,8 @@ class _InnerPredictor(object): ...@@ -529,6 +529,8 @@ class _InnerPredictor(object):
data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for prediction. Data source for prediction.
When data type is string, it represents the path of txt file. When data type is string, it represents the path of txt file.
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
num_iteration : int, optional (default=-1) num_iteration : int, optional (default=-1)
Iteration used for prediction. Iteration used for prediction.
raw_score : bool, optional (default=False) raw_score : bool, optional (default=False)
...@@ -560,8 +562,6 @@ class _InnerPredictor(object): ...@@ -560,8 +562,6 @@ class _InnerPredictor(object):
if pred_contrib: if pred_contrib:
predict_type = C_API_PREDICT_CONTRIB predict_type = C_API_PREDICT_CONTRIB
int_data_has_header = 1 if data_has_header else 0 int_data_has_header = 1 if data_has_header else 0
if num_iteration > self.num_total_iteration:
num_iteration = self.num_total_iteration
if isinstance(data, string_type): if isinstance(data, string_type):
with _TempFile() as f: with _TempFile() as f:
...@@ -570,6 +570,7 @@ class _InnerPredictor(object): ...@@ -570,6 +570,7 @@ class _InnerPredictor(object):
c_str(data), c_str(data),
ctypes.c_int(int_data_has_header), ctypes.c_int(int_data_has_header),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
c_str(self.pred_parameter), c_str(self.pred_parameter),
c_str(f.name))) c_str(f.name)))
...@@ -578,26 +579,26 @@ class _InnerPredictor(object): ...@@ -578,26 +579,26 @@ class _InnerPredictor(object):
preds = [float(token) for line in lines for token in line.split('\t')] preds = [float(token) for line in lines for token in line.split('\t')]
preds = np.array(preds, dtype=np.float64, copy=False) preds = np.array(preds, dtype=np.float64, copy=False)
elif isinstance(data, scipy.sparse.csr_matrix): elif isinstance(data, scipy.sparse.csr_matrix):
preds, nrow = self.__pred_for_csr(data, num_iteration, predict_type) preds, nrow = self.__pred_for_csr(data, start_iteration, num_iteration, predict_type)
elif isinstance(data, scipy.sparse.csc_matrix): elif isinstance(data, scipy.sparse.csc_matrix):
preds, nrow = self.__pred_for_csc(data, num_iteration, predict_type) preds, nrow = self.__pred_for_csc(data, start_iteration, num_iteration, predict_type)
elif isinstance(data, np.ndarray): elif isinstance(data, np.ndarray):
preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type) preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type)
elif isinstance(data, list): elif isinstance(data, list):
try: try:
data = np.array(data) data = np.array(data)
except BaseException: except BaseException:
raise ValueError('Cannot convert data list to numpy array.') raise ValueError('Cannot convert data list to numpy array.')
preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type) preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type)
elif isinstance(data, DataTable): elif isinstance(data, DataTable):
preds, nrow = self.__pred_for_np2d(data.to_numpy(), num_iteration, predict_type) preds, nrow = self.__pred_for_np2d(data.to_numpy(), start_iteration, num_iteration, predict_type)
else: else:
try: try:
warnings.warn('Converting data to scipy sparse matrix.') warnings.warn('Converting data to scipy sparse matrix.')
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
except BaseException: except BaseException:
raise TypeError('Cannot predict data for type {}'.format(type(data).__name__)) raise TypeError('Cannot predict data for type {}'.format(type(data).__name__))
preds, nrow = self.__pred_for_csr(csr, num_iteration, predict_type) preds, nrow = self.__pred_for_csr(csr, start_iteration, num_iteration, predict_type)
if pred_leaf: if pred_leaf:
preds = preds.astype(np.int32) preds = preds.astype(np.int32)
is_sparse = scipy.sparse.issparse(preds) or isinstance(preds, list) is_sparse = scipy.sparse.issparse(preds) or isinstance(preds, list)
...@@ -609,7 +610,7 @@ class _InnerPredictor(object): ...@@ -609,7 +610,7 @@ class _InnerPredictor(object):
% (preds.size, nrow)) % (preds.size, nrow))
return preds return preds
def __get_num_preds(self, num_iteration, nrow, predict_type): def __get_num_preds(self, start_iteration, num_iteration, nrow, predict_type):
"""Get size of prediction result.""" """Get size of prediction result."""
if nrow > MAX_INT32: if nrow > MAX_INT32:
raise LightGBMError('LightGBM cannot perform prediction for data' raise LightGBMError('LightGBM cannot perform prediction for data'
...@@ -621,22 +622,23 @@ class _InnerPredictor(object): ...@@ -621,22 +622,23 @@ class _InnerPredictor(object):
self.handle, self.handle,
ctypes.c_int(nrow), ctypes.c_int(nrow),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.byref(n_preds))) ctypes.byref(n_preds)))
return n_preds.value return n_preds.value
def __pred_for_np2d(self, mat, num_iteration, predict_type): def __pred_for_np2d(self, mat, start_iteration, num_iteration, predict_type):
"""Predict for a 2-D numpy matrix.""" """Predict for a 2-D numpy matrix."""
if len(mat.shape) != 2: if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray or list must be 2 dimensional') raise ValueError('Input numpy.ndarray or list must be 2 dimensional')
def inner_predict(mat, num_iteration, predict_type, preds=None): def inner_predict(mat, start_iteration, num_iteration, predict_type, preds=None):
if mat.dtype == np.float32 or mat.dtype == np.float64: if mat.dtype == np.float32 or mat.dtype == np.float64:
data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False) data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False)
else: # change non-float data to float data, need to copy else: # change non-float data to float data, need to copy
data = np.array(mat.reshape(mat.size), dtype=np.float32) data = np.array(mat.reshape(mat.size), dtype=np.float32)
ptr_data, type_ptr_data, _ = c_float_array(data) ptr_data, type_ptr_data, _ = c_float_array(data)
n_preds = self.__get_num_preds(num_iteration, mat.shape[0], predict_type) n_preds = self.__get_num_preds(start_iteration, num_iteration, mat.shape[0], predict_type)
if preds is None: if preds is None:
preds = np.zeros(n_preds, dtype=np.float64) preds = np.zeros(n_preds, dtype=np.float64)
elif len(preds.shape) != 1 or len(preds) != n_preds: elif len(preds.shape) != 1 or len(preds) != n_preds:
...@@ -650,6 +652,7 @@ class _InnerPredictor(object): ...@@ -650,6 +652,7 @@ class _InnerPredictor(object):
ctypes.c_int(mat.shape[1]), ctypes.c_int(mat.shape[1]),
ctypes.c_int(C_API_IS_ROW_MAJOR), ctypes.c_int(C_API_IS_ROW_MAJOR),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
c_str(self.pred_parameter), c_str(self.pred_parameter),
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
...@@ -662,16 +665,16 @@ class _InnerPredictor(object): ...@@ -662,16 +665,16 @@ class _InnerPredictor(object):
if nrow > MAX_INT32: if nrow > MAX_INT32:
sections = np.arange(start=MAX_INT32, stop=nrow, step=MAX_INT32) sections = np.arange(start=MAX_INT32, stop=nrow, step=MAX_INT32)
# __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal # __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal
n_preds = [self.__get_num_preds(num_iteration, i, predict_type) for i in np.diff([0] + list(sections) + [nrow])] n_preds = [self.__get_num_preds(start_iteration, num_iteration, i, predict_type) for i in np.diff([0] + list(sections) + [nrow])]
n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum() n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum()
preds = np.zeros(sum(n_preds), dtype=np.float64) preds = np.zeros(sum(n_preds), dtype=np.float64)
for chunk, (start_idx_pred, end_idx_pred) in zip_(np.array_split(mat, sections), for chunk, (start_idx_pred, end_idx_pred) in zip_(np.array_split(mat, sections),
zip_(n_preds_sections, n_preds_sections[1:])): zip_(n_preds_sections, n_preds_sections[1:])):
# avoid memory consumption by arrays concatenation operations # avoid memory consumption by arrays concatenation operations
inner_predict(chunk, num_iteration, predict_type, preds[start_idx_pred:end_idx_pred]) inner_predict(chunk, start_iteration, num_iteration, predict_type, preds[start_idx_pred:end_idx_pred])
return preds, nrow return preds, nrow
else: else:
return inner_predict(mat, num_iteration, predict_type) return inner_predict(mat, start_iteration, num_iteration, predict_type)
def __create_sparse_native(self, cs, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data, def __create_sparse_native(self, cs, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data,
indptr_type, data_type, is_csr=True): indptr_type, data_type, is_csr=True):
...@@ -719,11 +722,11 @@ class _InnerPredictor(object): ...@@ -719,11 +722,11 @@ class _InnerPredictor(object):
return cs_output_matrices[0] return cs_output_matrices[0]
return cs_output_matrices return cs_output_matrices
def __pred_for_csr(self, csr, num_iteration, predict_type): def __pred_for_csr(self, csr, start_iteration, num_iteration, predict_type):
"""Predict for a CSR data.""" """Predict for a CSR data."""
def inner_predict(csr, num_iteration, predict_type, preds=None): def inner_predict(csr, start_iteration, num_iteration, predict_type, preds=None):
nrow = len(csr.indptr) - 1 nrow = len(csr.indptr) - 1
n_preds = self.__get_num_preds(num_iteration, nrow, predict_type) n_preds = self.__get_num_preds(start_iteration, num_iteration, nrow, predict_type)
if preds is None: if preds is None:
preds = np.zeros(n_preds, dtype=np.float64) preds = np.zeros(n_preds, dtype=np.float64)
elif len(preds.shape) != 1 or len(preds) != n_preds: elif len(preds.shape) != 1 or len(preds) != n_preds:
...@@ -747,6 +750,7 @@ class _InnerPredictor(object): ...@@ -747,6 +750,7 @@ class _InnerPredictor(object):
ctypes.c_int64(len(csr.data)), ctypes.c_int64(len(csr.data)),
ctypes.c_int64(csr.shape[1]), ctypes.c_int64(csr.shape[1]),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
c_str(self.pred_parameter), c_str(self.pred_parameter),
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
...@@ -755,7 +759,7 @@ class _InnerPredictor(object): ...@@ -755,7 +759,7 @@ class _InnerPredictor(object):
raise ValueError("Wrong length for predict results") raise ValueError("Wrong length for predict results")
return preds, nrow return preds, nrow
def inner_predict_sparse(csr, num_iteration, predict_type): def inner_predict_sparse(csr, start_iteration, num_iteration, predict_type):
ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr) ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr)
ptr_data, type_ptr_data, _ = c_float_array(csr.data) ptr_data, type_ptr_data, _ = c_float_array(csr.data)
csr_indices = csr.indices.astype(np.int32, copy=False) csr_indices = csr.indices.astype(np.int32, copy=False)
...@@ -781,6 +785,7 @@ class _InnerPredictor(object): ...@@ -781,6 +785,7 @@ class _InnerPredictor(object):
ctypes.c_int64(len(csr.data)), ctypes.c_int64(len(csr.data)),
ctypes.c_int64(csr.shape[1]), ctypes.c_int64(csr.shape[1]),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
c_str(self.pred_parameter), c_str(self.pred_parameter),
ctypes.c_int(matrix_type), ctypes.c_int(matrix_type),
...@@ -794,25 +799,25 @@ class _InnerPredictor(object): ...@@ -794,25 +799,25 @@ class _InnerPredictor(object):
return matrices, nrow return matrices, nrow
if predict_type == C_API_PREDICT_CONTRIB: if predict_type == C_API_PREDICT_CONTRIB:
return inner_predict_sparse(csr, num_iteration, predict_type) return inner_predict_sparse(csr, start_iteration, num_iteration, predict_type)
nrow = len(csr.indptr) - 1 nrow = len(csr.indptr) - 1
if nrow > MAX_INT32: if nrow > MAX_INT32:
sections = [0] + list(np.arange(start=MAX_INT32, stop=nrow, step=MAX_INT32)) + [nrow] sections = [0] + list(np.arange(start=MAX_INT32, stop=nrow, step=MAX_INT32)) + [nrow]
# __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal # __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal
n_preds = [self.__get_num_preds(num_iteration, i, predict_type) for i in np.diff(sections)] n_preds = [self.__get_num_preds(start_iteration, num_iteration, i, predict_type) for i in np.diff(sections)]
n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum() n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum()
preds = np.zeros(sum(n_preds), dtype=np.float64) preds = np.zeros(sum(n_preds), dtype=np.float64)
for (start_idx, end_idx), (start_idx_pred, end_idx_pred) in zip_(zip_(sections, sections[1:]), for (start_idx, end_idx), (start_idx_pred, end_idx_pred) in zip_(zip_(sections, sections[1:]),
zip_(n_preds_sections, n_preds_sections[1:])): zip_(n_preds_sections, n_preds_sections[1:])):
# avoid memory consumption by arrays concatenation operations # avoid memory consumption by arrays concatenation operations
inner_predict(csr[start_idx:end_idx], num_iteration, predict_type, preds[start_idx_pred:end_idx_pred]) inner_predict(csr[start_idx:end_idx], start_iteration, num_iteration, predict_type, preds[start_idx_pred:end_idx_pred])
return preds, nrow return preds, nrow
else: else:
return inner_predict(csr, num_iteration, predict_type) return inner_predict(csr, start_iteration, num_iteration, predict_type)
def __pred_for_csc(self, csc, num_iteration, predict_type): def __pred_for_csc(self, csc, start_iteration, num_iteration, predict_type):
"""Predict for a CSC data.""" """Predict for a CSC data."""
def inner_predict_sparse(csc, num_iteration, predict_type): def inner_predict_sparse(csc, start_iteration, num_iteration, predict_type):
ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr) ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr)
ptr_data, type_ptr_data, _ = c_float_array(csc.data) ptr_data, type_ptr_data, _ = c_float_array(csc.data)
csc_indices = csc.indices.astype(np.int32, copy=False) csc_indices = csc.indices.astype(np.int32, copy=False)
...@@ -838,6 +843,7 @@ class _InnerPredictor(object): ...@@ -838,6 +843,7 @@ class _InnerPredictor(object):
ctypes.c_int64(len(csc.data)), ctypes.c_int64(len(csc.data)),
ctypes.c_int64(csc.shape[0]), ctypes.c_int64(csc.shape[0]),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
c_str(self.pred_parameter), c_str(self.pred_parameter),
ctypes.c_int(matrix_type), ctypes.c_int(matrix_type),
...@@ -852,10 +858,10 @@ class _InnerPredictor(object): ...@@ -852,10 +858,10 @@ class _InnerPredictor(object):
nrow = csc.shape[0] nrow = csc.shape[0]
if nrow > MAX_INT32: if nrow > MAX_INT32:
return self.__pred_for_csr(csc.tocsr(), num_iteration, predict_type) return self.__pred_for_csr(csc.tocsr(), start_iteration, num_iteration, predict_type)
if predict_type == C_API_PREDICT_CONTRIB: if predict_type == C_API_PREDICT_CONTRIB:
return inner_predict_sparse(csc, num_iteration, predict_type) return inner_predict_sparse(csc, start_iteration, num_iteration, predict_type)
n_preds = self.__get_num_preds(num_iteration, nrow, predict_type) n_preds = self.__get_num_preds(start_iteration, num_iteration, nrow, predict_type)
preds = np.zeros(n_preds, dtype=np.float64) preds = np.zeros(n_preds, dtype=np.float64)
out_num_preds = ctypes.c_int64(0) out_num_preds = ctypes.c_int64(0)
...@@ -876,6 +882,7 @@ class _InnerPredictor(object): ...@@ -876,6 +882,7 @@ class _InnerPredictor(object):
ctypes.c_int64(len(csc.data)), ctypes.c_int64(len(csc.data)),
ctypes.c_int64(csc.shape[0]), ctypes.c_int64(csc.shape[0]),
ctypes.c_int(predict_type), ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
c_str(self.pred_parameter), c_str(self.pred_parameter),
ctypes.byref(out_num_preds), ctypes.byref(out_num_preds),
...@@ -2806,7 +2813,7 @@ class Booster(object): ...@@ -2806,7 +2813,7 @@ class Booster(object):
default=json_default_with_numpy)) default=json_default_with_numpy))
return ret return ret
def predict(self, data, num_iteration=None, def predict(self, data, start_iteration=None, num_iteration=None,
raw_score=False, pred_leaf=False, pred_contrib=False, raw_score=False, pred_leaf=False, pred_contrib=False,
data_has_header=False, is_reshape=True, **kwargs): data_has_header=False, is_reshape=True, **kwargs):
"""Make a prediction. """Make a prediction.
...@@ -2816,10 +2823,14 @@ class Booster(object): ...@@ -2816,10 +2823,14 @@ class Booster(object):
data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for prediction. Data source for prediction.
If string, it represents the path to txt file. If string, it represents the path to txt file.
start_iteration : int or None, optional (default=None)
Start index of the iteration to predict.
If None or <= 0, starts from the first iteration.
num_iteration : int or None, optional (default=None) num_iteration : int or None, optional (default=None)
Limit number of iterations in the prediction. Limit number of iterations in the prediction.
If None, if the best iteration exists, it is used; otherwise, all iterations are used. If None, if the best iteration exists and start_iteration is None or <= 0, the best iteration is used;
If <= 0, all iterations are used (no limits). otherwise, all iterations from start_iteration are used.
If <= 0, all iterations from start_iteration are used (no limits).
raw_score : bool, optional (default=False) raw_score : bool, optional (default=False)
Whether to predict raw scores. Whether to predict raw scores.
pred_leaf : bool, optional (default=False) pred_leaf : bool, optional (default=False)
...@@ -2850,9 +2861,14 @@ class Booster(object): ...@@ -2850,9 +2861,14 @@ class Booster(object):
Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``). Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
""" """
predictor = self._to_predictor(copy.deepcopy(kwargs)) predictor = self._to_predictor(copy.deepcopy(kwargs))
if start_iteration is None or start_iteration < 0:
start_iteration = 0
if num_iteration is None: if num_iteration is None:
if start_iteration == 0:
num_iteration = self.best_iteration num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, else:
num_iteration = -1
return predictor.predict(data, start_iteration, num_iteration,
raw_score, pred_leaf, pred_contrib, raw_score, pred_leaf, pred_contrib,
data_has_header, is_reshape) data_has_header, is_reshape)
......
...@@ -612,7 +612,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -612,7 +612,7 @@ class LGBMModel(_LGBMModelBase):
del train_set, valid_sets del train_set, valid_sets
return self return self
def predict(self, X, raw_score=False, num_iteration=None, def predict(self, X, raw_score=False, start_iteration=None, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs): pred_leaf=False, pred_contrib=False, **kwargs):
"""Return the predicted value for each sample. """Return the predicted value for each sample.
...@@ -622,6 +622,9 @@ class LGBMModel(_LGBMModelBase): ...@@ -622,6 +622,9 @@ class LGBMModel(_LGBMModelBase):
Input features matrix. Input features matrix.
raw_score : bool, optional (default=False) raw_score : bool, optional (default=False)
Whether to predict raw scores. Whether to predict raw scores.
start_iteration : int or None, optional (default=None)
Start index of the iteration to predict.
If None or <= 0, starts from the first iteration.
num_iteration : int or None, optional (default=None) num_iteration : int or None, optional (default=None)
Limit number of iterations in the prediction. Limit number of iterations in the prediction.
If None, if the best iteration exists, it is used; otherwise, all trees are used. If None, if the best iteration exists, it is used; otherwise, all trees are used.
...@@ -661,7 +664,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -661,7 +664,7 @@ class LGBMModel(_LGBMModelBase):
"match the input. Model n_features_ is %s and " "match the input. Model n_features_ is %s and "
"input n_features is %s " "input n_features is %s "
% (self._n_features, n_features)) % (self._n_features, n_features))
return self._Booster.predict(X, raw_score=raw_score, num_iteration=num_iteration, return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration,
pred_leaf=pred_leaf, pred_contrib=pred_contrib, **kwargs) pred_leaf=pred_leaf, pred_contrib=pred_contrib, **kwargs)
@property @property
...@@ -832,10 +835,10 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -832,10 +835,10 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
fit.__doc__ = LGBMModel.fit.__doc__ fit.__doc__ = LGBMModel.fit.__doc__
def predict(self, X, raw_score=False, num_iteration=None, def predict(self, X, raw_score=False, start_iteration=None, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs): pred_leaf=False, pred_contrib=False, **kwargs):
"""Docstring is inherited from the LGBMModel.""" """Docstring is inherited from the LGBMModel."""
result = self.predict_proba(X, raw_score, num_iteration, result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
pred_leaf, pred_contrib, **kwargs) pred_leaf, pred_contrib, **kwargs)
if callable(self._objective) or raw_score or pred_leaf or pred_contrib: if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
return result return result
...@@ -845,7 +848,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -845,7 +848,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
predict.__doc__ = LGBMModel.predict.__doc__ predict.__doc__ = LGBMModel.predict.__doc__
def predict_proba(self, X, raw_score=False, num_iteration=None, def predict_proba(self, X, raw_score=False, start_iteration=None, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs): pred_leaf=False, pred_contrib=False, **kwargs):
"""Return the predicted probability for each class for each sample. """Return the predicted probability for each class for each sample.
...@@ -855,6 +858,9 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -855,6 +858,9 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
Input features matrix. Input features matrix.
raw_score : bool, optional (default=False) raw_score : bool, optional (default=False)
Whether to predict raw scores. Whether to predict raw scores.
start_iteration : int or None, optional (default=None)
Start index of the iteration to predict.
If None or <= 0, starts from the first iteration.
num_iteration : int or None, optional (default=None) num_iteration : int or None, optional (default=None)
Limit number of iterations in the prediction. Limit number of iterations in the prediction.
If None, if the best iteration exists, it is used; otherwise, all trees are used. If None, if the best iteration exists, it is used; otherwise, all trees are used.
...@@ -884,7 +890,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -884,7 +890,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
X_SHAP_values : array-like of shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects X_SHAP_values : array-like of shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects
If ``pred_contrib=True``, the feature contributions for each sample. If ``pred_contrib=True``, the feature contributions for each sample.
""" """
result = super(LGBMClassifier, self).predict(X, raw_score, num_iteration, result = super(LGBMClassifier, self).predict(X, raw_score, start_iteration, num_iteration,
pred_leaf, pred_contrib, **kwargs) pred_leaf, pred_contrib, **kwargs)
if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib): if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
warnings.warn("Cannot compute class probabilities or labels " warnings.warn("Cannot compute class probabilities or labels "
......
...@@ -88,7 +88,7 @@ void Application::LoadData() { ...@@ -88,7 +88,7 @@ void Application::LoadData() {
PredictFunction predict_fun = nullptr; PredictFunction predict_fun = nullptr;
// need to continue training // need to continue training
if (boosting_->NumberOfTotalModel() > 0 && config_.task != TaskType::KRefitTree) { if (boosting_->NumberOfTotalModel() > 0 && config_.task != TaskType::KRefitTree) {
predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1)); predictor.reset(new Predictor(boosting_.get(), 0, -1, true, false, false, false, -1, -1));
predict_fun = predictor->GetPredictFunction(); predict_fun = predictor->GetPredictFunction();
} }
...@@ -213,7 +213,7 @@ void Application::Train() { ...@@ -213,7 +213,7 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
if (config_.task == TaskType::KRefitTree) { if (config_.task == TaskType::KRefitTree) {
// create predictor // create predictor
Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1); Predictor predictor(boosting_.get(), 0, -1, false, true, false, false, 1, 1);
predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check); predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check);
TextReader<int> result_reader(config_.output_result.c_str(), false); TextReader<int> result_reader(config_.output_result.c_str(), false);
result_reader.ReadAllLines(); result_reader.ReadAllLines();
...@@ -239,7 +239,7 @@ void Application::Predict() { ...@@ -239,7 +239,7 @@ void Application::Predict() {
Log::Info("Finished RefitTree"); Log::Info("Finished RefitTree");
} else { } else {
// create predictor // create predictor
Predictor predictor(boosting_.get(), config_.num_iteration_predict, config_.predict_raw_score, Predictor predictor(boosting_.get(), config_.start_iteration_predict, config_.num_iteration_predict, config_.predict_raw_score,
config_.predict_leaf_index, config_.predict_contrib, config_.predict_leaf_index, config_.predict_contrib,
config_.pred_early_stop, config_.pred_early_stop_freq, config_.pred_early_stop, config_.pred_early_stop_freq,
config_.pred_early_stop_margin); config_.pred_early_stop_margin);
......
...@@ -31,12 +31,13 @@ class Predictor { ...@@ -31,12 +31,13 @@ class Predictor {
/*! /*!
* \brief Constructor * \brief Constructor
* \param boosting Input boosting model * \param boosting Input boosting model
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of boosting round * \param num_iteration Number of boosting round
* \param is_raw_score True if need to predict result with raw score * \param is_raw_score True if need to predict result with raw score
* \param predict_leaf_index True to output leaf index instead of prediction score * \param predict_leaf_index True to output leaf index instead of prediction score
* \param predict_contrib True to output feature contributions instead of prediction score * \param predict_contrib True to output feature contributions instead of prediction score
*/ */
Predictor(Boosting* boosting, int num_iteration, bool is_raw_score, Predictor(Boosting* boosting, int start_iteration, int num_iteration, bool is_raw_score,
bool predict_leaf_index, bool predict_contrib, bool early_stop, bool predict_leaf_index, bool predict_contrib, bool early_stop,
int early_stop_freq, double early_stop_margin) { int early_stop_freq, double early_stop_margin) {
early_stop_ = CreatePredictionEarlyStopInstance( early_stop_ = CreatePredictionEarlyStopInstance(
...@@ -56,9 +57,9 @@ class Predictor { ...@@ -56,9 +57,9 @@ class Predictor {
} }
} }
boosting->InitPredict(num_iteration, predict_contrib); boosting->InitPredict(start_iteration, num_iteration, predict_contrib);
boosting_ = boosting; boosting_ = boosting;
num_pred_one_row_ = boosting_->NumPredictOneRow( num_pred_one_row_ = boosting_->NumPredictOneRow(start_iteration,
num_iteration, predict_leaf_index, predict_contrib); num_iteration, predict_leaf_index, predict_contrib);
num_feature_ = boosting_->MaxFeatureIdx() + 1; num_feature_ = boosting_->MaxFeatureIdx() + 1;
predict_buf_.resize( predict_buf_.resize(
...@@ -225,6 +226,7 @@ class Predictor { ...@@ -225,6 +226,7 @@ class Predictor {
data_size_t, const std::vector<std::string>& lines) { data_size_t, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, double>> oneline_features;
std::vector<std::string> result_to_write(lines.size()); std::vector<std::string> result_to_write(lines.size());
Log::Warning("before predict_fun_ is called");
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) firstprivate(oneline_features) #pragma omp parallel for schedule(static) firstprivate(oneline_features)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) { for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
...@@ -239,6 +241,7 @@ class Predictor { ...@@ -239,6 +241,7 @@ class Predictor {
result_to_write[i] = str_result; result_to_write[i] = str_result;
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
Log::Warning("after predict_fun_ is called");
OMP_THROW_EX(); OMP_THROW_EX();
for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) { for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
writer->Write(result_to_write[i].c_str(), result_to_write[i].size()); writer->Write(result_to_write[i].c_str(), result_to_write[i].size());
......
...@@ -574,7 +574,8 @@ void GBDT::PredictContrib(const double* features, double* output) const { ...@@ -574,7 +574,8 @@ void GBDT::PredictContrib(const double* features, double* output) const {
// set zero // set zero
const int num_features = max_feature_idx_ + 1; const int num_features = max_feature_idx_ + 1;
std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * (num_features + 1)); std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * (num_features + 1));
for (int i = 0; i < num_iteration_for_pred_; ++i) { const int end_iteration_for_pred = start_iteration_for_pred_ + num_iteration_for_pred_;
for (int i = start_iteration_for_pred_; i < end_iteration_for_pred; ++i) {
// predict all the trees for one iteration // predict all the trees for one iteration
for (int k = 0; k < num_tree_per_iteration_; ++k) { for (int k = 0; k < num_tree_per_iteration_; ++k) {
models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k*(num_features + 1)); models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k*(num_features + 1));
...@@ -585,7 +586,8 @@ void GBDT::PredictContrib(const double* features, double* output) const { ...@@ -585,7 +586,8 @@ void GBDT::PredictContrib(const double* features, double* output) const {
void GBDT::PredictContribByMap(const std::unordered_map<int, double>& features, void GBDT::PredictContribByMap(const std::unordered_map<int, double>& features,
std::vector<std::unordered_map<int, double>>* output) const { std::vector<std::unordered_map<int, double>>* output) const {
const int num_features = max_feature_idx_ + 1; const int num_features = max_feature_idx_ + 1;
for (int i = 0; i < num_iteration_for_pred_; ++i) { const int end_iteration_for_pred = start_iteration_for_pred_ + num_iteration_for_pred_;
for (int i = start_iteration_for_pred_; i < end_iteration_for_pred; ++i) {
// predict all the trees for one iteration // predict all the trees for one iteration
for (int k = 0; k < num_tree_per_iteration_; ++k) { for (int k = 0; k < num_tree_per_iteration_; ++k) {
models_[i * num_tree_per_iteration_ + k]->PredictContribByMap(features, num_features, &((*output)[k])); models_[i * num_tree_per_iteration_ + k]->PredictContribByMap(features, num_features, &((*output)[k]));
......
...@@ -204,19 +204,22 @@ class GBDT : public GBDTBase { ...@@ -204,19 +204,22 @@ class GBDT : public GBDTBase {
/*! /*!
* \brief Get number of prediction for one data * \brief Get number of prediction for one data
* \param start_iteration Start index of the iteration to predict
* \param num_iteration number of used iterations * \param num_iteration number of used iterations
* \param is_pred_leaf True if predicting leaf index * \param is_pred_leaf True if predicting leaf index
* \param is_pred_contrib True if predicting feature contribution * \param is_pred_contrib True if predicting feature contribution
* \return number of prediction * \return number of prediction
*/ */
inline int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const override { inline int NumPredictOneRow(int start_iteration, int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const override {
int num_pred_in_one_row = num_class_; int num_pred_in_one_row = num_class_;
if (is_pred_leaf) { if (is_pred_leaf) {
int max_iteration = GetCurrentIteration(); int max_iteration = GetCurrentIteration();
start_iteration = std::max(start_iteration, 0);
start_iteration = std::min(start_iteration, max_iteration);
if (num_iteration > 0) { if (num_iteration > 0) {
num_pred_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration)); num_pred_in_one_row *= static_cast<int>(std::min(max_iteration - start_iteration, num_iteration));
} else { } else {
num_pred_in_one_row *= max_iteration; num_pred_in_one_row *= (max_iteration - start_iteration);
} }
} else if (is_pred_contrib) { } else if (is_pred_contrib) {
num_pred_in_one_row = num_tree_per_iteration_ * (max_feature_idx_ + 2); // +1 for 0-based indexing, +1 for baseline num_pred_in_one_row = num_tree_per_iteration_ * (max_feature_idx_ + 2); // +1 for 0-based indexing, +1 for baseline
...@@ -352,11 +355,16 @@ class GBDT : public GBDTBase { ...@@ -352,11 +355,16 @@ class GBDT : public GBDTBase {
*/ */
inline int NumberOfClasses() const override { return num_class_; } inline int NumberOfClasses() const override { return num_class_; }
inline void InitPredict(int num_iteration, bool is_pred_contrib) override { inline void InitPredict(int start_iteration, int num_iteration, bool is_pred_contrib) override {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_; num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
start_iteration = std::max(start_iteration, 0);
start_iteration = std::min(start_iteration, num_iteration_for_pred_);
if (num_iteration > 0) { if (num_iteration > 0) {
num_iteration_for_pred_ = std::min(num_iteration, num_iteration_for_pred_); num_iteration_for_pred_ = std::min(num_iteration, num_iteration_for_pred_ - start_iteration);
} else {
num_iteration_for_pred_ = num_iteration_for_pred_ - start_iteration;
} }
start_iteration_for_pred_ = start_iteration;
if (is_pred_contrib) { if (is_pred_contrib) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(models_.size()); ++i) { for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
...@@ -489,6 +497,8 @@ class GBDT : public GBDTBase { ...@@ -489,6 +497,8 @@ class GBDT : public GBDTBase {
data_size_t label_idx_; data_size_t label_idx_;
/*! \brief number of used model */ /*! \brief number of used model */
int num_iteration_for_pred_; int num_iteration_for_pred_;
/*! \brief Start iteration of used model */
int start_iteration_for_pred_;
/*! \brief Shrinkage rate for one iteration */ /*! \brief Shrinkage rate for one iteration */
double shrinkage_rate_; double shrinkage_rate_;
/*! \brief Number of loaded initial models */ /*! \brief Number of loaded initial models */
......
...@@ -14,7 +14,8 @@ void GBDT::PredictRaw(const double* features, double* output, const PredictionEa ...@@ -14,7 +14,8 @@ void GBDT::PredictRaw(const double* features, double* output, const PredictionEa
int early_stop_round_counter = 0; int early_stop_round_counter = 0;
// set zero // set zero
std::memset(output, 0, sizeof(double) * num_tree_per_iteration_); std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
for (int i = 0; i < num_iteration_for_pred_; ++i) { const int end_iteration_for_pred = start_iteration_for_pred_ + num_iteration_for_pred_;
for (int i = start_iteration_for_pred_; i < end_iteration_for_pred; ++i) {
// predict all the trees for one iteration // predict all the trees for one iteration
for (int k = 0; k < num_tree_per_iteration_; ++k) { for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] += models_[i * num_tree_per_iteration_ + k]->Predict(features); output[k] += models_[i * num_tree_per_iteration_ + k]->Predict(features);
...@@ -34,7 +35,8 @@ void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, doub ...@@ -34,7 +35,8 @@ void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, doub
int early_stop_round_counter = 0; int early_stop_round_counter = 0;
// set zero // set zero
std::memset(output, 0, sizeof(double) * num_tree_per_iteration_); std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
for (int i = 0; i < num_iteration_for_pred_; ++i) { const int end_iteration_for_pred = start_iteration_for_pred_ + num_iteration_for_pred_;
for (int i = start_iteration_for_pred_; i < end_iteration_for_pred; ++i) {
// predict all the trees for one iteration // predict all the trees for one iteration
for (int k = 0; k < num_tree_per_iteration_; ++k) { for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] += models_[i * num_tree_per_iteration_ + k]->PredictByMap(features); output[k] += models_[i * num_tree_per_iteration_ + k]->PredictByMap(features);
...@@ -75,16 +77,20 @@ void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* ...@@ -75,16 +77,20 @@ void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double*
} }
void GBDT::PredictLeafIndex(const double* features, double* output) const { void GBDT::PredictLeafIndex(const double* features, double* output) const {
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_; int start_tree = start_iteration_for_pred_ * num_tree_per_iteration_;
for (int i = 0; i < total_tree; ++i) { int num_trees = num_iteration_for_pred_ * num_tree_per_iteration_;
output[i] = models_[i]->PredictLeafIndex(features); const auto* models_ptr = models_.data() + start_tree;
for (int i = 0; i < num_trees; ++i) {
output[i] = models_ptr[i]->PredictLeafIndex(features);
} }
} }
void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const { void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_; int start_tree = start_iteration_for_pred_ * num_tree_per_iteration_;
for (int i = 0; i < total_tree; ++i) { int num_trees = num_iteration_for_pred_ * num_tree_per_iteration_;
output[i] = models_[i]->PredictLeafIndexByMap(features); const auto* models_ptr = models_.data() + start_tree;
for (int i = 0; i < num_trees; ++i) {
output[i] = models_ptr[i]->PredictLeafIndexByMap(features);
} }
} }
......
...@@ -62,7 +62,7 @@ class SingleRowPredictor { ...@@ -62,7 +62,7 @@ class SingleRowPredictor {
PredictFunction predict_function; PredictFunction predict_function;
int64_t num_pred_in_one_row; int64_t num_pred_in_one_row;
SingleRowPredictor(int predict_type, Boosting* boosting, const Config& config, int iter) { SingleRowPredictor(int predict_type, Boosting* boosting, const Config& config, int start_iter, int num_iter) {
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
bool predict_contrib = false; bool predict_contrib = false;
...@@ -78,10 +78,10 @@ class SingleRowPredictor { ...@@ -78,10 +78,10 @@ class SingleRowPredictor {
early_stop_ = config.pred_early_stop; early_stop_ = config.pred_early_stop;
early_stop_freq_ = config.pred_early_stop_freq; early_stop_freq_ = config.pred_early_stop_freq;
early_stop_margin_ = config.pred_early_stop_margin; early_stop_margin_ = config.pred_early_stop_margin;
iter_ = iter; iter_ = num_iter;
predictor_.reset(new Predictor(boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib, predictor_.reset(new Predictor(boosting, start_iter, iter_, is_raw_score, is_predict_leaf, predict_contrib,
early_stop_, early_stop_freq_, early_stop_margin_)); early_stop_, early_stop_freq_, early_stop_margin_));
num_pred_in_one_row = boosting->NumPredictOneRow(iter_, is_predict_leaf, predict_contrib); num_pred_in_one_row = boosting->NumPredictOneRow(start_iter, iter_, is_predict_leaf, predict_contrib);
predict_function = predictor_->GetPredictFunction(); predict_function = predictor_->GetPredictFunction();
num_total_model_ = boosting->NumberOfTotalModel(); num_total_model_ = boosting->NumberOfTotalModel();
} }
...@@ -369,12 +369,12 @@ class Booster { ...@@ -369,12 +369,12 @@ class Booster {
boosting_->RollbackOneIter(); boosting_->RollbackOneIter();
} }
void SetSingleRowPredictor(int num_iteration, int predict_type, const Config& config) { void SetSingleRowPredictor(int start_iteration, int num_iteration, int predict_type, const Config& config) {
UNIQUE_LOCK(mutex_) UNIQUE_LOCK(mutex_)
if (single_row_predictor_[predict_type].get() == nullptr || if (single_row_predictor_[predict_type].get() == nullptr ||
!single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) { !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(), single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
config, num_iteration)); config, start_iteration, num_iteration));
} }
} }
...@@ -395,7 +395,7 @@ class Booster { ...@@ -395,7 +395,7 @@ class Booster {
*out_len = single_row_predictor->num_pred_in_one_row; *out_len = single_row_predictor->num_pred_in_one_row;
} }
Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) const { Predictor CreatePredictor(int start_iteration, int num_iteration, int predict_type, int ncol, const Config& config) const {
if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) { if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \ Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1); "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
...@@ -413,17 +413,17 @@ class Booster { ...@@ -413,17 +413,17 @@ class Booster {
is_raw_score = false; is_raw_score = false;
} }
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib, Predictor predictor(boosting_.get(), start_iteration, num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin); config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
return predictor; return predictor;
} }
void Predict(int num_iteration, int predict_type, int nrow, int ncol, void Predict(int start_iteration, int num_iteration, int predict_type, int nrow, int ncol,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config, const Config& config,
double* out_result, int64_t* out_len) const { double* out_result, int64_t* out_len) const {
SHARED_LOCK(mutex_); SHARED_LOCK(mutex_);
auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config); auto predictor = CreatePredictor(start_iteration, num_iteration, predict_type, ncol, config);
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool predict_contrib = false; bool predict_contrib = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) { if (predict_type == C_API_PREDICT_LEAF_INDEX) {
...@@ -431,7 +431,7 @@ class Booster { ...@@ -431,7 +431,7 @@ class Booster {
} else if (predict_type == C_API_PREDICT_CONTRIB) { } else if (predict_type == C_API_PREDICT_CONTRIB) {
predict_contrib = true; predict_contrib = true;
} }
int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib); int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(start_iteration, num_iteration, is_predict_leaf, predict_contrib);
auto pred_fun = predictor.GetPredictFunction(); auto pred_fun = predictor.GetPredictFunction();
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
...@@ -446,13 +446,13 @@ class Booster { ...@@ -446,13 +446,13 @@ class Booster {
*out_len = num_pred_in_one_row * nrow; *out_len = num_pred_in_one_row * nrow;
} }
void PredictSparse(int num_iteration, int predict_type, int64_t nrow, int ncol, void PredictSparse(int start_iteration, int num_iteration, int predict_type, int64_t nrow, int ncol,
std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
const Config& config, int64_t* out_elements_size, const Config& config, int64_t* out_elements_size,
std::vector<std::vector<std::unordered_map<int, double>>>* agg_ptr, std::vector<std::vector<std::unordered_map<int, double>>>* agg_ptr,
int32_t** out_indices, void** out_data, int data_type, int32_t** out_indices, void** out_data, int data_type,
bool* is_data_float32_ptr, int num_matrices) const { bool* is_data_float32_ptr, int num_matrices) const {
auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config); auto predictor = CreatePredictor(start_iteration, num_iteration, predict_type, ncol, config);
auto pred_sparse_fun = predictor.GetPredictSparseFunction(); auto pred_sparse_fun = predictor.GetPredictSparseFunction();
std::vector<std::vector<std::unordered_map<int, double>>>& agg = *agg_ptr; std::vector<std::vector<std::unordered_map<int, double>>>& agg = *agg_ptr;
OMP_INIT_EX(); OMP_INIT_EX();
...@@ -488,7 +488,7 @@ class Booster { ...@@ -488,7 +488,7 @@ class Booster {
*out_indices = new int32_t[elements_size]; *out_indices = new int32_t[elements_size];
} }
void PredictSparseCSR(int num_iteration, int predict_type, int64_t nrow, int ncol, void PredictSparseCSR(int start_iteration, int num_iteration, int predict_type, int64_t nrow, int ncol,
std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
const Config& config, const Config& config,
int64_t* out_len, void** out_indptr, int indptr_type, int64_t* out_len, void** out_indptr, int indptr_type,
...@@ -511,7 +511,7 @@ class Booster { ...@@ -511,7 +511,7 @@ class Booster {
// aggregated per row feature contribution results // aggregated per row feature contribution results
std::vector<std::vector<std::unordered_map<int, double>>> agg(nrow); std::vector<std::vector<std::unordered_map<int, double>>> agg(nrow);
int64_t elements_size = 0; int64_t elements_size = 0;
PredictSparse(num_iteration, predict_type, nrow, ncol, get_row_fun, config, &elements_size, &agg, PredictSparse(start_iteration, num_iteration, predict_type, nrow, ncol, get_row_fun, config, &elements_size, &agg,
out_indices, out_data, data_type, &is_data_float32, num_matrices); out_indices, out_data, data_type, &is_data_float32, num_matrices);
std::vector<int> row_sizes(num_matrices * nrow); std::vector<int> row_sizes(num_matrices * nrow);
std::vector<int64_t> row_matrix_offsets(num_matrices * nrow); std::vector<int64_t> row_matrix_offsets(num_matrices * nrow);
...@@ -572,7 +572,7 @@ class Booster { ...@@ -572,7 +572,7 @@ class Booster {
out_len[1] = indptr_size; out_len[1] = indptr_size;
} }
void PredictSparseCSC(int num_iteration, int predict_type, int64_t nrow, int ncol, void PredictSparseCSC(int start_iteration, int num_iteration, int predict_type, int64_t nrow, int ncol,
std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
const Config& config, const Config& config,
int64_t* out_len, void** out_col_ptr, int col_ptr_type, int64_t* out_len, void** out_col_ptr, int col_ptr_type,
...@@ -580,7 +580,7 @@ class Booster { ...@@ -580,7 +580,7 @@ class Booster {
SHARED_LOCK(mutex_); SHARED_LOCK(mutex_);
// Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices) // Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
int num_matrices = boosting_->NumModelPerIteration(); int num_matrices = boosting_->NumModelPerIteration();
auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config); auto predictor = CreatePredictor(start_iteration, num_iteration, predict_type, ncol, config);
auto pred_sparse_fun = predictor.GetPredictSparseFunction(); auto pred_sparse_fun = predictor.GetPredictSparseFunction();
bool is_col_ptr_int32 = false; bool is_col_ptr_int32 = false;
bool is_data_float32 = false; bool is_data_float32 = false;
...@@ -598,7 +598,7 @@ class Booster { ...@@ -598,7 +598,7 @@ class Booster {
// aggregated per row feature contribution results // aggregated per row feature contribution results
std::vector<std::vector<std::unordered_map<int, double>>> agg(nrow); std::vector<std::vector<std::unordered_map<int, double>>> agg(nrow);
int64_t elements_size = 0; int64_t elements_size = 0;
PredictSparse(num_iteration, predict_type, nrow, ncol, get_row_fun, config, &elements_size, &agg, PredictSparse(start_iteration, num_iteration, predict_type, nrow, ncol, get_row_fun, config, &elements_size, &agg,
out_indices, out_data, data_type, &is_data_float32, num_matrices); out_indices, out_data, data_type, &is_data_float32, num_matrices);
// calculate number of elements per column to construct // calculate number of elements per column to construct
// the CSC matrix with random access // the CSC matrix with random access
...@@ -676,7 +676,7 @@ class Booster { ...@@ -676,7 +676,7 @@ class Booster {
out_len[1] = col_ptr_size; out_len[1] = col_ptr_size;
} }
void Predict(int num_iteration, int predict_type, const char* data_filename, void Predict(int start_iteration, int num_iteration, int predict_type, const char* data_filename,
int data_has_header, const Config& config, int data_has_header, const Config& config,
const char* result_filename) const { const char* result_filename) const {
SHARED_LOCK(mutex_) SHARED_LOCK(mutex_)
...@@ -692,7 +692,7 @@ class Booster { ...@@ -692,7 +692,7 @@ class Booster {
} else { } else {
is_raw_score = false; is_raw_score = false;
} }
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib, Predictor predictor(boosting_.get(), start_iteration, num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin); config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
bool bool_data_has_header = data_has_header > 0 ? true : false; bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check); predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
...@@ -1728,6 +1728,7 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -1728,6 +1728,7 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename, const char* data_filename,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
const char* result_filename) { const char* result_filename) {
...@@ -1739,7 +1740,7 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -1739,7 +1740,7 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header, ref_booster->Predict(start_iteration, num_iteration, predict_type, data_filename, data_has_header,
config, result_filename); config, result_filename);
API_END(); API_END();
} }
...@@ -1747,11 +1748,12 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -1747,11 +1748,12 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
int LGBM_BoosterCalcNumPredict(BoosterHandle handle, int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int num_row, int num_row,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
int64_t* out_len) { int64_t* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = static_cast<int64_t>(num_row) * ref_booster->GetBoosting()->NumPredictOneRow( *out_len = static_cast<int64_t>(num_row) * ref_booster->GetBoosting()->NumPredictOneRow(start_iteration,
num_iteration, predict_type == C_API_PREDICT_LEAF_INDEX, predict_type == C_API_PREDICT_CONTRIB); num_iteration, predict_type == C_API_PREDICT_LEAF_INDEX, predict_type == C_API_PREDICT_CONTRIB);
API_END(); API_END();
} }
...@@ -1798,6 +1800,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -1798,6 +1800,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -1817,7 +1820,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -1817,7 +1820,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int nrow = static_cast<int>(nindptr - 1); int nrow = static_cast<int>(nindptr - 1);
ref_booster->Predict(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun, ref_booster->Predict(start_iteration, num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
config, out_result, out_len); config, out_result, out_len);
API_END(); API_END();
} }
...@@ -1832,6 +1835,7 @@ int LGBM_BoosterPredictSparseOutput(BoosterHandle handle, ...@@ -1832,6 +1835,7 @@ int LGBM_BoosterPredictSparseOutput(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_col_or_row, int64_t num_col_or_row,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int matrix_type, int matrix_type,
...@@ -1855,7 +1859,7 @@ int LGBM_BoosterPredictSparseOutput(BoosterHandle handle, ...@@ -1855,7 +1859,7 @@ int LGBM_BoosterPredictSparseOutput(BoosterHandle handle,
} }
auto get_row_fun = RowFunctionFromCSR<int64_t>(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR<int64_t>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int64_t nrow = nindptr - 1; int64_t nrow = nindptr - 1;
ref_booster->PredictSparseCSR(num_iteration, predict_type, nrow, static_cast<int>(num_col_or_row), get_row_fun, ref_booster->PredictSparseCSR(start_iteration, num_iteration, predict_type, nrow, static_cast<int>(num_col_or_row), get_row_fun,
config, out_len, out_indptr, indptr_type, out_indices, out_data, data_type); config, out_len, out_indptr, indptr_type, out_indices, out_data, data_type);
} else if (matrix_type == C_API_MATRIX_TYPE_CSC) { } else if (matrix_type == C_API_MATRIX_TYPE_CSC) {
int num_threads = OMP_NUM_THREADS(); int num_threads = OMP_NUM_THREADS();
...@@ -1879,7 +1883,7 @@ int LGBM_BoosterPredictSparseOutput(BoosterHandle handle, ...@@ -1879,7 +1883,7 @@ int LGBM_BoosterPredictSparseOutput(BoosterHandle handle,
} }
return one_row; return one_row;
}; };
ref_booster->PredictSparseCSC(num_iteration, predict_type, num_col_or_row, ncol, get_row_fun, config, ref_booster->PredictSparseCSC(start_iteration, num_iteration, predict_type, num_col_or_row, ncol, get_row_fun, config,
out_len, out_indptr, indptr_type, out_indices, out_data, data_type); out_len, out_indptr, indptr_type, out_indices, out_data, data_type);
} else { } else {
Log::Fatal("Unknown matrix type in LGBM_BoosterPredictSparseOutput"); Log::Fatal("Unknown matrix type in LGBM_BoosterPredictSparseOutput");
...@@ -1917,6 +1921,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -1917,6 +1921,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -1935,13 +1940,14 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -1935,13 +1940,14 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
ref_booster->SetSingleRowPredictor(num_iteration, predict_type, config); ref_booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, config);
ref_booster->PredictSingleRow(predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len); ref_booster->PredictSingleRow(predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
API_END(); API_END();
} }
int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
const int predict_type, const int predict_type,
const int start_iteration,
const int num_iteration, const int num_iteration,
const int data_type, const int data_type,
const int64_t num_col, const int64_t num_col,
...@@ -1965,7 +1971,7 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, ...@@ -1965,7 +1971,7 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
omp_set_num_threads(fastConfig_ptr->config.num_threads); omp_set_num_threads(fastConfig_ptr->config.num_threads);
} }
fastConfig_ptr->booster->SetSingleRowPredictor(num_iteration, predict_type, fastConfig_ptr->config); fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config);
*out_fastConfig = fastConfig_ptr.release(); *out_fastConfig = fastConfig_ptr.release();
API_END(); API_END();
...@@ -1999,6 +2005,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1999,6 +2005,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t nelem, int64_t nelem,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -2032,7 +2039,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -2032,7 +2039,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
} }
return one_row; return one_row;
}; };
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config, ref_booster->Predict(start_iteration, num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
out_result, out_len); out_result, out_len);
API_END(); API_END();
} }
...@@ -2044,6 +2051,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -2044,6 +2051,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -2057,7 +2065,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -2057,7 +2065,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, ref_booster->Predict(start_iteration, num_iteration, predict_type, nrow, ncol, get_row_fun,
config, out_result, out_len); config, out_result, out_len);
API_END(); API_END();
} }
...@@ -2068,6 +2076,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -2068,6 +2076,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -2081,13 +2090,14 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -2081,13 +2090,14 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
ref_booster->SetSingleRowPredictor(num_iteration, predict_type, config); ref_booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, config);
ref_booster->PredictSingleRow(predict_type, ncol, get_row_fun, config, out_result, out_len); ref_booster->PredictSingleRow(predict_type, ncol, get_row_fun, config, out_result, out_len);
API_END(); API_END();
} }
int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
const int predict_type, const int predict_type,
const int start_iteration,
const int num_iteration, const int num_iteration,
const int data_type, const int data_type,
const int32_t ncol, const int32_t ncol,
...@@ -2105,7 +2115,7 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, ...@@ -2105,7 +2115,7 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
omp_set_num_threads(fastConfig_ptr->config.num_threads); omp_set_num_threads(fastConfig_ptr->config.num_threads);
} }
fastConfig_ptr->booster->SetSingleRowPredictor(num_iteration, predict_type, fastConfig_ptr->config); fastConfig_ptr->booster->SetSingleRowPredictor(start_iteration, num_iteration, predict_type, fastConfig_ptr->config);
*out_fastConfig = fastConfig_ptr.release(); *out_fastConfig = fastConfig_ptr.release();
API_END(); API_END();
...@@ -2132,6 +2142,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle, ...@@ -2132,6 +2142,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -2145,7 +2156,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle, ...@@ -2145,7 +2156,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type); auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type);
ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len); ref_booster->Predict(start_iteration, num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len);
API_END(); API_END();
} }
......
...@@ -256,6 +256,7 @@ const std::unordered_set<std::string>& Config::parameter_set() { ...@@ -256,6 +256,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"categorical_feature", "categorical_feature",
"forcedbins_filename", "forcedbins_filename",
"save_binary", "save_binary",
"start_iteration_predict",
"num_iteration_predict", "num_iteration_predict",
"predict_raw_score", "predict_raw_score",
"predict_leaf_index", "predict_leaf_index",
...@@ -513,6 +514,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -513,6 +514,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetBool(params, "save_binary", &save_binary); GetBool(params, "save_binary", &save_binary);
GetInt(params, "start_iteration_predict", &start_iteration_predict);
GetInt(params, "num_iteration_predict", &num_iteration_predict); GetInt(params, "num_iteration_predict", &num_iteration_predict);
GetBool(params, "predict_raw_score", &predict_raw_score); GetBool(params, "predict_raw_score", &predict_raw_score);
......
...@@ -83,13 +83,14 @@ ...@@ -83,13 +83,14 @@
int ncol, int ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0); double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0);
int ret = LGBM_BoosterPredictForMatSingleRow(handle, data0, data_type, ncol, is_row_major, predict_type, int ret = LGBM_BoosterPredictForMatSingleRow(handle, data0, data_type, ncol, is_row_major, predict_type, start_iteration,
num_iteration, parameter, out_len, out_result); num_iteration, parameter, out_len, out_result);
jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT); jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT);
...@@ -130,6 +131,7 @@ ...@@ -130,6 +131,7 @@
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int start_iteration,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
...@@ -147,7 +149,7 @@ ...@@ -147,7 +149,7 @@
int32_t ind[2] = { 0, numNonZeros }; int32_t ind[2] = { 0, numNonZeros };
int ret = LGBM_BoosterPredictForCSRSingleRow(handle, ind, indptr_type, indices0, values0, data_type, 2, int ret = LGBM_BoosterPredictForCSRSingleRow(handle, ind, indptr_type, indices0, values0, data_type, 2,
nelem, num_col, predict_type, num_iteration, parameter, out_len, out_result); nelem, num_col, predict_type, start_iteration, num_iteration, parameter, out_len, out_result);
jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT); jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT);
jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT); jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment