"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "fffd066cb331a3573fc8565915c914ae5b6b8313"
Unverified Commit 85be04a6 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] Disabled early stopping when using 'dart' boosting strategy (#2443)

parent fc991c9d
# Single source of truth for the parameter aliases respected by the R package.
# Reference: https://lightgbm.readthedocs.io/en/latest/Parameters.html#core-parameters
# [description] Builds the table of accepted parameter aliases. It is a function rather
#               than a plain object so that it is only evaluated when called, which means
#               the order in which R sources the package files during installation is
#               irrelevant.
# [return] A named list. Each name is a main LightGBM parameter and each element is a
#          character vector holding that parameter's name followed by its aliases.
.PARAMETER_ALIASES <- function() {

    # Aliases selecting the boosting strategy
    boosting_aliases <- c("boosting", "boost", "boosting_type")

    # Aliases controlling early stopping
    early_stopping_aliases <- c(
        "early_stopping_round"
        , "early_stopping_rounds"
        , "early_stopping"
        , "n_iter_no_change"
    )

    # Aliases naming the evaluation metric(s)
    metric_aliases <- c("metric", "metrics", "metric_types")

    # Aliases for the number of classes
    num_class_aliases <- c("num_class", "num_classes")

    # Aliases for the number of boosting iterations
    num_iterations_aliases <- c(
        "num_iterations"
        , "num_iteration"
        , "n_iter"
        , "num_tree"
        , "num_trees"
        , "num_round"
        , "num_rounds"
        , "num_boost_round"
        , "n_estimators"
    )

    # Assemble in the same key order callers have always seen
    list(
        "boosting" = boosting_aliases
        , "early_stopping_round" = early_stopping_aliases
        , "metric" = metric_aliases
        , "num_class" = num_class_aliases
        , "num_iterations" = num_iterations_aliases
    )
}
...@@ -37,7 +37,11 @@ cb.reset.parameters <- function(new_params) { ...@@ -37,7 +37,11 @@ cb.reset.parameters <- function(new_params) {
# Some parameters are not allowed to be changed, # Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos # since changing them would simply wreck some chaos
not_allowed <- c("num_class", "metric", "boosting_type") not_allowed <- c(
.PARAMETER_ALIASES()[["num_class"]]
, .PARAMETER_ALIASES()[["metric"]]
, .PARAMETER_ALIASES()[["boosting"]]
)
if (any(pnames %in% not_allowed)) { if (any(pnames %in% not_allowed)) {
stop( stop(
"Parameters " "Parameters "
......
...@@ -136,17 +136,7 @@ lgb.cv <- function(params = list(), ...@@ -136,17 +136,7 @@ lgb.cv <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1 begin_iteration <- predictor$current_iter() + 1
} }
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one # Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
n_trees <- c( n_trees <- .PARAMETER_ALIASES()[["num_iterations"]]
"num_iterations"
, "num_iteration"
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
if (any(names(params) %in% n_trees)) { if (any(names(params) %in% n_trees)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1 end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
} else { } else {
...@@ -225,30 +215,52 @@ lgb.cv <- function(params = list(), ...@@ -225,30 +215,52 @@ lgb.cv <- function(params = list(),
callbacks <- add.cb(callbacks, cb.record.evaluation()) callbacks <- add.cb(callbacks, cb.record.evaluation())
} }
# Check for early stopping passed as parameter when adding early stopping callback # If early stopping was passed as a parameter in params(), prefer that to keyword argument
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change") # early_stopping_rounds by overwriting the value in 'early_stopping_rounds'
if (any(names(params) %in% early_stop)) { early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
if (params[[which(names(params) %in% early_stop)[1]]] > 0) { early_stop_param_indx <- names(params) %in% early_stop
callbacks <- add.cb( if (any(early_stop_param_indx)) {
callbacks first_early_stop_param <- which(early_stop_param_indx)[[1]]
, cb.early.stop( first_early_stop_param_name <- names(params)[[first_early_stop_param]]
params[[which(names(params) %in% early_stop)[1]]] early_stopping_rounds <- params[[first_early_stop_param_name]]
, verbose = verbose }
)
) # Did user pass parameters that indicate they want to use early stopping?
} using_early_stopping_via_args <- !is.null(early_stopping_rounds)
} else {
if (!is.null(early_stopping_rounds)) { boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
if (early_stopping_rounds > 0) { using_dart <- any(
callbacks <- add.cb( sapply(
callbacks X = boosting_param_names
, cb.early.stop( , FUN = function(param){
early_stopping_rounds identical(params[[param]], 'dart')
, verbose = verbose
)
)
} }
} )
)
# Cannot use early stopping with 'dart' boosting
if (using_dart){
warning("Early stopping is not available in 'dart' mode.")
using_early_stopping_via_args <- FALSE
# Remove the cb.early.stop() function if it was passed in to callbacks
callbacks <- Filter(
f = function(cb_func){
!identical(attr(cb_func, "name"), "cb.early.stop")
}
, x = callbacks
)
}
# If user supplied early_stopping_rounds, add the early stopping callback
if (using_early_stopping_via_args){
callbacks <- add.cb(
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, verbose = verbose
)
)
} }
# Categorize callbacks # Categorize callbacks
......
...@@ -108,24 +108,13 @@ lgb.train <- function(params = list(), ...@@ -108,24 +108,13 @@ lgb.train <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1 begin_iteration <- predictor$current_iter() + 1
} }
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one # Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
n_rounds <- c( n_trees <- .PARAMETER_ALIASES()[["num_iterations"]]
"num_iterations" if (any(names(params) %in% n_trees)) {
, "num_iteration" end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
if (any(names(params) %in% n_rounds)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_rounds)[1]]] - 1
} else { } else {
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
} }
# Check for training dataset type correctness # Check for training dataset type correctness
if (!lgb.is.Dataset(data)) { if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object") stop("lgb.train: data only accepts lgb.Dataset object")
...@@ -207,30 +196,52 @@ lgb.train <- function(params = list(), ...@@ -207,30 +196,52 @@ lgb.train <- function(params = list(),
callbacks <- add.cb(callbacks, cb.record.evaluation()) callbacks <- add.cb(callbacks, cb.record.evaluation())
} }
# Check for early stopping passed as parameter when adding early stopping callback # If early stopping was passed as a parameter in params(), prefer that to keyword argument
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change") # early_stopping_rounds by overwriting the value in 'early_stopping_rounds'
if (any(names(params) %in% early_stop)) { early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
if (params[[which(names(params) %in% early_stop)[1]]] > 0) { early_stop_param_indx <- names(params) %in% early_stop
callbacks <- add.cb( if (any(early_stop_param_indx)) {
callbacks first_early_stop_param <- which(early_stop_param_indx)[[1]]
, cb.early.stop( first_early_stop_param_name <- names(params)[[first_early_stop_param]]
params[[which(names(params) %in% early_stop)[1]]] early_stopping_rounds <- params[[first_early_stop_param_name]]
, verbose = verbose }
)
) # Did user pass parameters that indicate they want to use early stopping?
} using_early_stopping_via_args <- !is.null(early_stopping_rounds)
} else {
if (!is.null(early_stopping_rounds)) { boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
if (early_stopping_rounds > 0) { using_dart <- any(
callbacks <- add.cb( sapply(
callbacks X = boosting_param_names
, cb.early.stop( , FUN = function(param){
early_stopping_rounds identical(params[[param]], 'dart')
, verbose = verbose
)
)
} }
} )
)
# Cannot use early stopping with 'dart' boosting
if (using_dart){
warning("Early stopping is not available in 'dart' mode.")
using_early_stopping_via_args <- FALSE
# Remove the cb.early.stop() function if it was passed in to callbacks
callbacks <- Filter(
f = function(cb_func){
!identical(attr(cb_func, "name"), "cb.early.stop")
}
, x = callbacks
)
}
# If user supplied early_stopping_rounds, add the early stopping callback
if (using_early_stopping_via_args){
callbacks <- add.cb(
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, verbose = verbose
)
)
} }
# "Categorize" callbacks # "Categorize" callbacks
......
...@@ -43,3 +43,35 @@ test_that("Feature penalties work properly", { ...@@ -43,3 +43,35 @@ test_that("Feature penalties work properly", {
# Ensure that feature is not used when feature_penalty = 0 # Ensure that feature is not used when feature_penalty = 0
expect_length(var_gain[[length(var_gain)]], 0) expect_length(var_gain[[length(var_gain)]], 0)
}) })
# Checks the shape of the alias table: it must be a list with character names, and
# each main parameter must map to a character vector of aliases.
# NOTE: this must be wrapped in test_that(), not expect_true(). With expect_true() the
# description string itself becomes the object under test and the block below is not
# run as a registered test.
test_that(".PARAMETER_ALIASES() returns a named list", {
    param_aliases <- .PARAMETER_ALIASES()
    expect_true(is.list(param_aliases))
    expect_true(is.character(names(param_aliases)))
    expect_true(is.character(param_aliases[["boosting"]]))
    expect_true(is.character(param_aliases[["early_stopping_round"]]))
    expect_true(is.character(param_aliases[["metric"]]))
    expect_true(is.character(param_aliases[["num_class"]]))
    expect_true(is.character(param_aliases[["num_iterations"]]))
})
# For every accepted spelling of the 'boosting' parameter, training with 'dart' must
# raise the "Early stopping is not available in 'dart' mode" warning.
# NOTE: this must be wrapped in test_that(), not expect_true(). With expect_true() the
# description string itself becomes the object under test and the loop below is not
# run as a registered test.
# NOTE(review): `train` is not defined in this block; presumably it is created earlier
# in the test file. Confirm.
test_that("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", {
    for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]){
        expect_warning({
            result <- lightgbm(
                data = train$data
                , label = train$label
                , num_leaves = 5
                , learning_rate = 0.05
                , nrounds = 5
                , objective = "binary"
                , metric = "binary_error"
                , verbose = -1
                # Pass 'dart' under the alias currently being exercised
                , params = stats::setNames(
                    object = "dart"
                    , nm = boosting_param
                )
            )
        }, regexp = "Early stopping is not available in 'dart' mode")
    }
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment