Commit cb9fabda authored by Eric Graves, committed by Guolin Ke
Browse files

added code to expose c_api_pred_contrib in the R package (#1259)

* added code to expose c_api_pred_contrib in the R package

* removed Rprintf

* reverted to previous version of install.libs.R
parent 48ff86e6
...@@ -380,3 +380,4 @@ lightgbm.model ...@@ -380,3 +380,4 @@ lightgbm.model
# duplicate version file # duplicate version file
python-package/lightgbm/VERSION.txt python-package/lightgbm/VERSION.txt
.Rproj.user
...@@ -615,6 +615,7 @@ Booster <- R6Class( ...@@ -615,6 +615,7 @@ Booster <- R6Class(
#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for #' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for
#' logistic regression would result in predictions for log-odds instead of probabilities. #' logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead. #' @param predleaf whether predict leaf index instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header #' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several #' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case. #' prediction outputs per case.
...@@ -655,6 +656,7 @@ predict.lgb.Booster <- function(object, data, ...@@ -655,6 +656,7 @@ predict.lgb.Booster <- function(object, data,
num_iteration = NULL, num_iteration = NULL,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
predcontrib = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE, ...) { reshape = FALSE, ...) {
...@@ -668,6 +670,7 @@ predict.lgb.Booster <- function(object, data, ...@@ -668,6 +670,7 @@ predict.lgb.Booster <- function(object, data,
num_iteration, num_iteration,
rawscore, rawscore,
predleaf, predleaf,
predcontrib,
header, header,
reshape, ...) reshape, ...)
} }
......
...@@ -63,6 +63,7 @@ Predictor <- R6Class( ...@@ -63,6 +63,7 @@ Predictor <- R6Class(
num_iteration = NULL, num_iteration = NULL,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
predcontrib = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE) { reshape = FALSE) {
...@@ -86,6 +87,7 @@ Predictor <- R6Class( ...@@ -86,6 +87,7 @@ Predictor <- R6Class(
as.integer(header), as.integer(header),
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration), as.integer(num_iteration),
private$params, private$params,
lgb.c_str(tmp_filename)) lgb.c_str(tmp_filename))
...@@ -99,6 +101,7 @@ Predictor <- R6Class( ...@@ -99,6 +101,7 @@ Predictor <- R6Class(
# Not a file, we need to predict from R object # Not a file, we need to predict from R object
num_row <- nrow(data) num_row <- nrow(data)
npred <- 0L npred <- 0L
# Check number of predictions to do # Check number of predictions to do
...@@ -108,6 +111,7 @@ Predictor <- R6Class( ...@@ -108,6 +111,7 @@ Predictor <- R6Class(
as.integer(num_row), as.integer(num_row),
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration)) as.integer(num_iteration))
# Pre-allocate empty vector # Pre-allocate empty vector
...@@ -123,6 +127,7 @@ Predictor <- R6Class( ...@@ -123,6 +127,7 @@ Predictor <- R6Class(
as.integer(ncol(data)), as.integer(ncol(data)),
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration), as.integer(num_iteration),
private$params) private$params)
...@@ -142,6 +147,7 @@ Predictor <- R6Class( ...@@ -142,6 +147,7 @@ Predictor <- R6Class(
nrow(data), nrow(data),
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration), as.integer(num_iteration),
private$params) private$params)
...@@ -165,7 +171,7 @@ Predictor <- R6Class( ...@@ -165,7 +171,7 @@ Predictor <- R6Class(
# Data reshaping # Data reshaping
if (predleaf) { if (predleaf | predcontrib) {
# Predict leaves only, reshaping is mandatory # Predict leaves only, reshaping is mandatory
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE) preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
......
...@@ -388,6 +388,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle, ...@@ -388,6 +388,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE data_has_header, LGBM_SE data_has_header,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE result_filename, LGBM_SE result_filename,
...@@ -407,6 +408,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle, ...@@ -407,6 +408,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
LGBM_SE num_row, LGBM_SE num_row,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE out_len, LGBM_SE out_len,
LGBM_SE call_state); LGBM_SE call_state);
...@@ -438,6 +440,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -438,6 +440,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE num_row, LGBM_SE num_row,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
...@@ -464,6 +467,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -464,6 +467,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE ncol, LGBM_SE ncol,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
......
...@@ -479,7 +479,7 @@ LGBM_SE LGBM_BoosterGetPredict_R(LGBM_SE handle, ...@@ -479,7 +479,7 @@ LGBM_SE LGBM_BoosterGetPredict_R(LGBM_SE handle,
R_API_END(); R_API_END();
} }
int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx) { int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx, LGBM_SE is_predcontrib) {
int pred_type = C_API_PREDICT_NORMAL; int pred_type = C_API_PREDICT_NORMAL;
if (R_AS_INT(is_rawscore)) { if (R_AS_INT(is_rawscore)) {
pred_type = C_API_PREDICT_RAW_SCORE; pred_type = C_API_PREDICT_RAW_SCORE;
...@@ -487,6 +487,9 @@ int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx) { ...@@ -487,6 +487,9 @@ int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx) {
if (R_AS_INT(is_leafidx)) { if (R_AS_INT(is_leafidx)) {
pred_type = C_API_PREDICT_LEAF_INDEX; pred_type = C_API_PREDICT_LEAF_INDEX;
} }
if (R_AS_INT(is_predcontrib)) {
pred_type = C_API_PREDICT_CONTRIB;
}
return pred_type; return pred_type;
} }
...@@ -495,12 +498,13 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle, ...@@ -495,12 +498,13 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE data_has_header, LGBM_SE data_has_header,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE result_filename, LGBM_SE result_filename,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename), CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_CHAR_PTR(result_filename))); R_CHAR_PTR(result_filename)));
...@@ -511,11 +515,12 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle, ...@@ -511,11 +515,12 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
LGBM_SE num_row, LGBM_SE num_row,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE out_len, LGBM_SE out_len,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0; int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row), CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
pred_type, R_AS_INT(num_iteration), &len)); pred_type, R_AS_INT(num_iteration), &len));
...@@ -532,13 +537,14 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -532,13 +537,14 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE num_row, LGBM_SE num_row,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
const int* p_indptr = R_INT_PTR(indptr); const int* p_indptr = R_INT_PTR(indptr);
const int* p_indices = R_INT_PTR(indices); const int* p_indices = R_INT_PTR(indices);
...@@ -562,13 +568,14 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -562,13 +568,14 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE num_col, LGBM_SE num_col,
LGBM_SE is_rawscore, LGBM_SE is_rawscore,
LGBM_SE is_leafidx, LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result, LGBM_SE out_result,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = R_AS_INT(num_row); int32_t nrow = R_AS_INT(num_row);
int32_t ncol = R_AS_INT(num_col); int32_t ncol = R_AS_INT(num_col);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment