Unverified Commit 7d9106d2 authored by Michael Mayer's avatar Michael Mayer Committed by GitHub
Browse files

[R-package]: add num_trees_per_iter, num_trees, and num_iter methods (#6500)

parent 3a98ea13
......@@ -307,6 +307,46 @@ Booster <- R6::R6Class(
},
# Number of trees added per boosting iteration
# (written by the C side into an integer buffer we pass in)
num_trees_per_iter = function() {
  # Ensure the underlying C++ Booster handle is valid
  # (e.g. after the model was restored from a serialized state)
  self$restore_handle()
  # The .Call() target fills the first element of this integer vector
  result <- 1L
  .Call(
    LGBM_BoosterNumModelPerIteration_R
    , private$handle
    , result
  )
  return(result)
},
# Total number of trees stored in the Booster, across all iterations
num_trees = function() {
  # Ensure the underlying C++ Booster handle is valid
  self$restore_handle()
  # The .Call() target fills the first element of this integer vector
  total <- 0L
  .Call(
    LGBM_BoosterNumberOfTotalModel_R
    , private$handle
    , total
  )
  return(total)
},
# Number of boosting iterations (= rounds) performed so far.
# The total tree count is always an exact multiple of the trees added
# per iteration, so integer division (%/%) is lossless here and keeps
# the return type an integer, consistent with num_trees() and
# num_trees_per_iter() (plain `/` would promote the result to double).
num_iter = function() {
  ntrees <- self$num_trees()
  trees_per_iter <- self$num_trees_per_iter()
  return(ntrees %/% trees_per_iter)
},
# Get upper bound
upper_bound = function() {
......
......@@ -763,8 +763,7 @@ SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
R_API_END();
}
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
SEXP out) {
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int out_iteration;
......@@ -774,6 +773,26 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
R_API_END();
}
// Wrapper for the R side: query how many trees the Booster adds per
// boosting iteration and store the count in the first element of the
// integer vector `out`.  Returns R NULL on success.
SEXP LGBM_BoosterNumModelPerIteration_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int num_models;
  CHECK_CALL(LGBM_BoosterNumModelPerIteration(R_ExternalPtrAddr(handle), &num_models));
  INTEGER(out)[0] = num_models;
  return R_NilValue;
  R_API_END();
}
// Wrapper for the R side: query the total number of trees stored in the
// Booster (across all iterations) and store it in the first element of
// the integer vector `out`.  Returns R NULL on success.
SEXP LGBM_BoosterNumberOfTotalModel_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int num_models;
  CHECK_CALL(LGBM_BoosterNumberOfTotalModel(R_ExternalPtrAddr(handle), &num_models));
  INTEGER(out)[0] = num_models;
  return R_NilValue;
  R_API_END();
}
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
SEXP out_result) {
R_API_BEGIN();
......@@ -1431,6 +1450,8 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterUpdateOneIterCustom_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R , 4},
{"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1},
{"LGBM_BoosterGetCurrentIteration_R" , (DL_FUNC) &LGBM_BoosterGetCurrentIteration_R , 2},
{"LGBM_BoosterNumModelPerIteration_R" , (DL_FUNC) &LGBM_BoosterNumModelPerIteration_R , 2},
{"LGBM_BoosterNumberOfTotalModel_R" , (DL_FUNC) &LGBM_BoosterNumberOfTotalModel_R , 2},
{"LGBM_BoosterGetUpperBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetUpperBoundValue_R , 2},
{"LGBM_BoosterGetLowerBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetLowerBoundValue_R , 2},
{"LGBM_BoosterGetEvalNames_R" , (DL_FUNC) &LGBM_BoosterGetEvalNames_R , 1},
......
......@@ -384,6 +384,28 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R(
SEXP out
);
/*!
 * \brief Get the number of trees the Booster adds per boosting iteration
 *        (e.g. more than 1 for multiclass objectives)
 * \param handle Booster handle
 * \param out [out] Integer vector whose first element receives the
 *            number of trees per iteration
 * \return R NULL value
 */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterNumModelPerIteration_R(
SEXP handle,
SEXP out
);
/*!
 * \brief Get the total number of trees stored in the Booster,
 *        across all boosting iterations
 * \param handle Booster handle
 * \param out [out] Integer vector whose first element receives the
 *            total number of trees of the Booster
 * \return R NULL value
 */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterNumberOfTotalModel_R(
SEXP handle,
SEXP out
);
/*!
* \brief Get model upper bound value.
* \param handle Handle of Booster
......
......@@ -623,6 +623,174 @@ test_that("Booster$update() throws an informative error if you provide a non-Dat
}, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
})
test_that("Booster$num_trees_per_iter() works as expected", {
  set.seed(708L)
  X <- data.matrix(iris[2L:4L])
  y_reg <- iris[, 1L]
  y_binary <- as.integer(y_reg > median(y_reg))
  y_class <- as.integer(iris[, 5L]) - 1L
  num_class <- 3L
  nrounds <- 10L

  # helper: train a Booster on X for `nrounds` rounds with the given
  # label vector and objective-specific parameters
  train_booster <- function(label, extra_params) {
    lgb.train(
      params = c(
        list(verbose = .LGB_VERBOSITY, num_threads = .LGB_MAX_THREADS)
        , extra_params
      )
      , data = lgb.Dataset(X, label = label)
      , nrounds = nrounds
    )
  }

  # Regression and binary probabilistic classification (1 iteration = 1 tree)
  fit_reg <- train_booster(y_reg, list(objective = "mse"))
  fit_binary <- train_booster(y_binary, list(objective = "binary"))
  # Multiclass probabilistic classification (1 iteration = num_class trees)
  fit_class <- train_booster(
    y_class
    , list(objective = "multiclass", num_class = num_class)
  )

  expect_equal(fit_reg$num_trees_per_iter(), 1L)
  expect_equal(fit_binary$num_trees_per_iter(), 1L)
  expect_equal(fit_class$num_trees_per_iter(), num_class)
})
test_that("Booster$num_trees() and $num_iter() works (no early stopping)", {
  set.seed(708L)
  X <- data.matrix(iris[2L:4L])
  y_reg <- iris[, 1L]
  y_binary <- as.integer(y_reg > median(y_reg))
  y_class <- as.integer(iris[, 5L]) - 1L
  num_class <- 3L
  nrounds <- 10L

  # helper: train a Booster on X for a fixed number of rounds
  # (no validation set, so no early stopping can kick in)
  train_booster <- function(label, extra_params) {
    lgb.train(
      params = c(
        list(verbose = .LGB_VERBOSITY, num_threads = .LGB_MAX_THREADS)
        , extra_params
      )
      , data = lgb.Dataset(X, label = label)
      , nrounds = nrounds
    )
  }

  # Regression and binary probabilistic classification (1 iteration = 1 tree)
  fit_reg <- train_booster(y_reg, list(objective = "mse"))
  fit_binary <- train_booster(y_binary, list(objective = "binary"))
  # Multiclass probabilistic classification (1 iteration = num_class trees)
  fit_class <- train_booster(
    y_class
    , list(objective = "multiclass", num_class = num_class)
  )

  # Without early stopping, every requested round is trained
  expect_equal(fit_reg$num_trees(), nrounds)
  expect_equal(fit_binary$num_trees(), nrounds)
  expect_equal(fit_class$num_trees(), num_class * nrounds)
  expect_equal(fit_reg$num_iter(), nrounds)
  expect_equal(fit_binary$num_iter(), nrounds)
  expect_equal(fit_class$num_iter(), nrounds)
})
test_that("Booster$num_trees() and $num_iter() work (with early stopping)", {
  set.seed(708L)
  X <- data.matrix(iris[2L:4L])
  y_reg <- iris[, 1L]
  y_binary <- as.integer(y_reg > median(y_reg))
  y_class <- as.integer(iris[, 5L]) - 1L
  train_ix <- c(1L:40L, 51L:90L, 101L:140L)
  X_train <- X[train_ix, ]
  X_valid <- X[-train_ix, ]
  num_class <- 3L
  nrounds <- 1000L
  early_stopping <- 2L

  # helper: train a Booster with a held-out validation set and
  # early stopping enabled
  train_with_valid <- function(y, extra_params) {
    lgb.train(
      params = c(
        list(verbose = .LGB_VERBOSITY, num_threads = .LGB_MAX_THREADS)
        , extra_params
      )
      , data = lgb.Dataset(X_train, label = y[train_ix])
      , valids = list(valid = lgb.Dataset(X_valid, label = y[-train_ix]))
      , nrounds = nrounds
      , early_stopping_round = early_stopping
    )
  }

  # Regression and binary probabilistic classification (1 iteration = 1 tree)
  fit_reg <- train_with_valid(y_reg, list(objective = "mse"))
  fit_binary <- train_with_valid(y_binary, list(objective = "binary"))
  # Multiclass probabilistic classification (1 iteration = num_class trees)
  fit_class <- train_with_valid(
    y_class
    , list(objective = "multiclass", num_class = num_class)
  )

  # After early stopping, the `early_stopping` rounds trained past the
  # best iteration are still kept in the model, so they count too
  expected_iters_reg <- fit_reg$best_iter + early_stopping
  expected_iters_binary <- fit_binary$best_iter + early_stopping
  expected_iters_class <- fit_class$best_iter + early_stopping

  expect_equal(fit_reg$num_trees(), expected_iters_reg)
  expect_equal(fit_binary$num_trees(), expected_iters_binary)
  expect_equal(fit_class$num_trees(), expected_iters_class * num_class)
  expect_equal(fit_reg$num_iter(), expected_iters_reg)
  expect_equal(fit_binary$num_iter(), expected_iters_binary)
  expect_equal(fit_class$num_iter(), expected_iters_class)
})
test_that("Booster should store parameters and Booster$reset_parameter() should update them", {
data(agaricus.train, package = "lightgbm")
dtrain <- lgb.Dataset(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment