Unverified Commit 2f59773d authored by david-cortes's avatar david-cortes Committed by GitHub
Browse files

[R-package] Add `print()` and `summary()` methods for Booster (#4686)



* add print and summary S3 method

* correct wrong signature

* attempt at bypassing linter

* Update R-package/R/lgb.Booster.R
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* Update R-package/src/lightgbm_R.h
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* Update include/LightGBM/c_api.h
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* add more tests

* linter

* don't pluralize single tree

* remove duplicated function

* update changed function name

* missing declaration

* Update lightgbm_R.h

* Update R-package/tests/testthat/test_lgb.Booster.R
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* accommodate custom objectives in print

* linter

* linter
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
parent 6e6fb14c
...@@ -6,9 +6,11 @@ S3method(dimnames,lgb.Dataset) ...@@ -6,9 +6,11 @@ S3method(dimnames,lgb.Dataset)
S3method(get_field,lgb.Dataset) S3method(get_field,lgb.Dataset)
S3method(getinfo,lgb.Dataset) S3method(getinfo,lgb.Dataset)
S3method(predict,lgb.Booster) S3method(predict,lgb.Booster)
S3method(print,lgb.Booster)
S3method(set_field,lgb.Dataset) S3method(set_field,lgb.Dataset)
S3method(setinfo,lgb.Dataset) S3method(setinfo,lgb.Dataset)
S3method(slice,lgb.Dataset) S3method(slice,lgb.Dataset)
S3method(summary,lgb.Booster)
export(get_field) export(get_field)
export(getinfo) export(getinfo)
export(lgb.Dataset) export(lgb.Dataset)
......
...@@ -814,6 +814,65 @@ predict.lgb.Booster <- function(object, ...@@ -814,6 +814,65 @@ predict.lgb.Booster <- function(object,
) )
} }
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
# nolint start
handle <- x$.__enclos_env__$private$handle
handle_is_null <- lgb.is.null.handle(handle)
if (!handle_is_null) {
ntrees <- x$current_iter()
if (ntrees == 1L) {
cat("LightGBM Model (1 tree)\n")
} else {
cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
}
} else {
cat("LightGBM Model\n")
}
if (!handle_is_null) {
obj <- x$params$objective
if (obj == "none") {
obj <- "custom"
}
if (x$.__enclos_env__$private$num_class == 1L) {
cat(sprintf("Objective: %s\n", obj))
} else {
cat(sprintf("Objective: %s (%d classes)\n"
, obj
, x$.__enclos_env__$private$num_class))
}
} else {
cat("(Booster handle is invalid)\n")
}
if (!handle_is_null) {
ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
cat(sprintf("Fitted to dataset with %d columns\n", ncols))
}
# nolint end
return(invisible(x))
}
#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
print(object)
}
#' @name lgb.load #' @name lgb.load
#' @title Load LightGBM model #' @title Load LightGBM model
#' @description Load LightGBM takes in either a file path or model string. #' @description Load LightGBM takes in either a file path or model string.
......
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/lgb.Booster.R
\name{print.lgb.Booster}
\alias{print.lgb.Booster}
\title{Print method for LightGBM model}
\usage{
\method{print}{lgb.Booster}(x, ...)
}
\arguments{
\item{x}{Object of class \code{lgb.Booster}}
\item{...}{Not used}
}
\value{
The same input `x`, returned as invisible.
}
\description{
Show summary information about a LightGBM model object (same as \code{summary}).
}
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/lgb.Booster.R
\name{summary.lgb.Booster}
\alias{summary.lgb.Booster}
\title{Summary method for LightGBM model}
\usage{
\method{summary}{lgb.Booster}(object, ...)
}
\arguments{
\item{object}{Object of class \code{lgb.Booster}}
\item{...}{Not used}
}
\value{
The same input `object`, returned as invisible.
}
\description{
Show summary information about a LightGBM model object (same as \code{print}).
}
...@@ -525,6 +525,15 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, ...@@ -525,6 +525,15 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int out = 0;
CHECK_CALL(LGBM_BoosterGetNumFeature(R_ExternalPtrAddr(handle), &out));
return Rf_ScalarInteger(out);
R_API_END();
}
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) { SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN(); R_API_BEGIN();
_AssertBoosterHandleNotNull(handle); _AssertBoosterHandleNotNull(handle);
...@@ -889,6 +898,7 @@ static const R_CallMethodDef CallEntries[] = { ...@@ -889,6 +898,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterResetTrainingData_R" , (DL_FUNC) &LGBM_BoosterResetTrainingData_R , 2}, {"LGBM_BoosterResetTrainingData_R" , (DL_FUNC) &LGBM_BoosterResetTrainingData_R , 2},
{"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2}, {"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2},
{"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2}, {"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2},
{"LGBM_BoosterGetNumFeature_R" , (DL_FUNC) &LGBM_BoosterGetNumFeature_R , 1},
{"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1}, {"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1},
{"LGBM_BoosterUpdateOneIterCustom_R", (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R, 4}, {"LGBM_BoosterUpdateOneIterCustom_R", (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R, 4},
{"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1}, {"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1},
......
...@@ -302,6 +302,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R( ...@@ -302,6 +302,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R(
SEXP out SEXP out
); );
/*!
* \brief Get number of features.
* \param handle Booster handle
* \return Total number of features, as R integer
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumFeature_R(
SEXP handle
);
/*! /*!
* \brief update the model in one round * \brief update the model in one round
* \param handle Booster handle * \param handle Booster handle
......
...@@ -1041,3 +1041,116 @@ test_that("boosters with linear models at leaves can be written to RDS and re-lo ...@@ -1041,3 +1041,116 @@ test_that("boosters with linear models at leaves can be written to RDS and re-lo
preds2 <- predict(bst2, X) preds2 <- predict(bst2, X)
expect_identical(preds, preds2) expect_identical(preds, preds2)
}) })
test_that("Booster's print, show, and summary work correctly", {
.have_same_handle <- function(model, other_model) {
expect_equal(
model$.__enclos_env__$private$handle
, other_model$.__enclos_env__$private$handle
)
}
.check_methods_work <- function(model) {
# should work for fitted models
ret <- print(model)
.have_same_handle(ret, model)
ret <- show(model)
expect_null(ret)
ret <- summary(model)
.have_same_handle(ret, model)
# should not fail for finalized models
model$finalize()
ret <- print(model)
.have_same_handle(ret, model)
ret <- show(model)
expect_null(ret)
ret <- summary(model)
.have_same_handle(ret, model)
}
data("mtcars")
model <- lgb.train(
params = list(objective = "regression")
, data = lgb.Dataset(
as.matrix(mtcars[, -1L])
, label = mtcars$mpg)
, verbose = 0L
, nrounds = 5L
)
.check_methods_work(model)
data("iris")
model <- lgb.train(
params = list(objective = "multiclass", num_class = 3L)
, data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(factor(iris$Species)) - 1.0
)
, verbose = 0L
, nrounds = 5L
)
.check_methods_work(model)
# with custom objective
.logregobj <- function(preds, dtrain) {
labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds))
grad <- preds - labels
hess <- preds * (1.0 - preds)
return(list(grad = grad, hess = hess))
}
.evalerror <- function(preds, dtrain) {
labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds))
err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
return(list(
name = "error"
, value = err
, higher_better = FALSE
))
}
model <- lgb.train(
data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(iris$Species == "virginica")
)
, obj = .logregobj
, eval = .evalerror
, verbose = 0L
, nrounds = 5L
)
.check_methods_work(model)
})
test_that("LGBM_BoosterGetNumFeature_R returns correct outputs", {
data("mtcars")
model <- lgb.train(
params = list(objective = "regression")
, data = lgb.Dataset(
as.matrix(mtcars[, -1L])
, label = mtcars$mpg)
, verbose = 0L
, nrounds = 5L
)
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(mtcars) - 1L)
data("iris")
model <- lgb.train(
params = list(objective = "multiclass", num_class = 3L)
, data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(factor(iris$Species)) - 1.0
)
, verbose = 0L
, nrounds = 5L
)
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(iris) - 1L)
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment