Commit cb9fabda authored by Eric Graves's avatar Eric Graves Committed by Guolin Ke
Browse files

added code to expose c_api_pred_contrib in the R package (#1259)

* added code to expose c_api_pred_contrib in the R package

* removed Rprintf

* reverted to previous version of install.libs.R
parent 48ff86e6
......@@ -380,3 +380,4 @@ lightgbm.model
# duplicate version file
python-package/lightgbm/VERSION.txt
.Rproj.user
......@@ -615,6 +615,7 @@ Booster <- R6Class(
#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for
#' logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case.
......@@ -655,6 +656,7 @@ predict.lgb.Booster <- function(object, data,
num_iteration = NULL,
rawscore = FALSE,
predleaf = FALSE,
predcontrib = FALSE,
header = FALSE,
reshape = FALSE, ...) {
......@@ -668,6 +670,7 @@ predict.lgb.Booster <- function(object, data,
num_iteration,
rawscore,
predleaf,
predcontrib,
header,
reshape, ...)
}
......
......@@ -63,6 +63,7 @@ Predictor <- R6Class(
num_iteration = NULL,
rawscore = FALSE,
predleaf = FALSE,
predcontrib = FALSE,
header = FALSE,
reshape = FALSE) {
......@@ -86,6 +87,7 @@ Predictor <- R6Class(
as.integer(header),
as.integer(rawscore),
as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration),
private$params,
lgb.c_str(tmp_filename))
......@@ -99,6 +101,7 @@ Predictor <- R6Class(
# Not a file, we need to predict from R object
num_row <- nrow(data)
npred <- 0L
# Check number of predictions to do
......@@ -108,6 +111,7 @@ Predictor <- R6Class(
as.integer(num_row),
as.integer(rawscore),
as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration))
# Pre-allocate empty vector
......@@ -123,6 +127,7 @@ Predictor <- R6Class(
as.integer(ncol(data)),
as.integer(rawscore),
as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration),
private$params)
......@@ -142,6 +147,7 @@ Predictor <- R6Class(
nrow(data),
as.integer(rawscore),
as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration),
private$params)
......@@ -165,7 +171,7 @@ Predictor <- R6Class(
# Data reshaping
if (predleaf) {
if (predleaf | predcontrib) {
# Predict leaves only, reshaping is mandatory
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
......
......@@ -388,6 +388,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE data_has_header,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename,
......@@ -407,6 +408,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE out_len,
LGBM_SE call_state);
......@@ -438,6 +440,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
......@@ -464,6 +467,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE ncol,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
......
......@@ -479,7 +479,7 @@ LGBM_SE LGBM_BoosterGetPredict_R(LGBM_SE handle,
R_API_END();
}
int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx) {
int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx, LGBM_SE is_predcontrib) {
int pred_type = C_API_PREDICT_NORMAL;
if (R_AS_INT(is_rawscore)) {
pred_type = C_API_PREDICT_RAW_SCORE;
......@@ -487,6 +487,9 @@ int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx) {
if (R_AS_INT(is_leafidx)) {
pred_type = C_API_PREDICT_LEAF_INDEX;
}
if (R_AS_INT(is_predcontrib)) {
pred_type = C_API_PREDICT_CONTRIB;
}
return pred_type;
}
......@@ -495,12 +498,13 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE data_has_header,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename,
LGBM_SE call_state) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_CHAR_PTR(result_filename)));
......@@ -511,11 +515,12 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE out_len,
LGBM_SE call_state) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
pred_type, R_AS_INT(num_iteration), &len));
......@@ -532,13 +537,14 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
LGBM_SE call_state) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
const int* p_indptr = R_INT_PTR(indptr);
const int* p_indices = R_INT_PTR(indices);
......@@ -562,13 +568,14 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE num_col,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
LGBM_SE call_state) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = R_AS_INT(num_row);
int32_t ncol = R_AS_INT(num_col);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment