"python-package/vscode:/vscode.git/clone" did not exist on "d93eb338481be57e9a94d0aa742d978322d03cda"
Commit f90dfaf3 authored by Bernie Gray's avatar Bernie Gray Committed by Guolin Ke
Browse files

minor changes to R package for robustness and efficiency (#910)

* minor changes for robustness and efficiency

* fix charcter typo

* don't use lgb.is.* in filter

* revert to cat instead of message; don't use mget

* fix bugs in previous commits

* should be good to go now
parent f01fba9f
...@@ -132,12 +132,7 @@ merge.eval.string <- function(env) { ...@@ -132,12 +132,7 @@ merge.eval.string <- function(env) {
msg <- list(sprintf("[%d]:", env$iteration)) msg <- list(sprintf("[%d]:", env$iteration))
# Set if evaluation error # Set if evaluation error
is_eval_err <- FALSE is_eval_error <- length(env$eval_err_list) > 0
# Check evaluation error list length
if (length(env$eval_err_list) > 0) {
is_eval_err <- TRUE
}
# Loop through evaluation list # Loop through evaluation list
for (j in seq_along(env$eval_list)) { for (j in seq_along(env$eval_list)) {
...@@ -170,7 +165,7 @@ cb.print.evaluation <- function(period = 1) { ...@@ -170,7 +165,7 @@ cb.print.evaluation <- function(period = 1) {
i <- env$iteration i <- env$iteration
# Check if iteration matches moduo # Check if iteration matches moduo
if ((i - 1) %% period == 0 | i == env$begin_iteration | i == env$end_iteration ) { if ((i - 1) %% period == 0 || is.element(i, c(env$begin_iteration, env$end_iteration ))) {
# Merge evaluation string # Merge evaluation string
msg <- merge.eval.string(env) msg <- merge.eval.string(env)
...@@ -206,12 +201,7 @@ cb.record.evaluation <- function() { ...@@ -206,12 +201,7 @@ cb.record.evaluation <- function() {
} }
# Set if evaluation error # Set if evaluation error
is_eval_err <- FALSE is_eval_err <- length(env$eval_err_list) > 0
# Check evaluation error list length
if (length(env$eval_err_list) > 0) {
is_eval_err <- TRUE
}
# Check length of recorded evaluation # Check length of recorded evaluation
if (length(env$model$record_evals) == 0) { if (length(env$model$record_evals) == 0) {
...@@ -295,9 +285,9 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) { ...@@ -295,9 +285,9 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
} }
# Maximization or minimization task # Maximization or minimization task
factor_to_bigger_better <<- rep(1.0, eval_len) factor_to_bigger_better <<- rep.int(1.0, eval_len)
best_iter <<- rep(-1, eval_len) best_iter <<- rep.int(-1, eval_len)
best_score <<- rep(-Inf, eval_len) best_score <<- rep.int(-Inf, eval_len)
best_msg <<- list() best_msg <<- list()
# Loop through evaluation elements # Loop through evaluation elements
......
...@@ -224,7 +224,7 @@ Booster <- R6Class( ...@@ -224,7 +224,7 @@ Booster <- R6Class(
gpair <- fobj(private$inner_predict(1), private$train_set) gpair <- fobj(private$inner_predict(1), private$train_set)
# Check for gradient and hessian as list # Check for gradient and hessian as list
if(is.null(gpair$grad) | is.null(gpair$hess)){ if(is.null(gpair$grad) || is.null(gpair$hess)){
stop("lgb.Booster.update: custom objective should stop("lgb.Booster.update: custom objective should
return a list with attributes (hess, grad)") return a list with attributes (hess, grad)")
} }
...@@ -510,16 +510,7 @@ Booster <- R6Class( ...@@ -510,16 +510,7 @@ Booster <- R6Class(
# Parse and store privately names # Parse and store privately names
names <- strsplit(names, "\t")[[1]] names <- strsplit(names, "\t")[[1]]
private$eval_names <- names private$eval_names <- names
private$higher_better_inner_eval <- rep(FALSE, length(names)) private$higher_better_inner_eval <- grepl("^ndcg|^auc$", names)
# Loop through each name to pick up evaluation (and parse ndcg manually)
for (i in seq_along(names)) {
if ((names[i] == "auc") | grepl("^ndcg", names[i])) {
private$higher_better_inner_eval[i] <- TRUE
}
}
} }
...@@ -589,7 +580,7 @@ Booster <- R6Class( ...@@ -589,7 +580,7 @@ Booster <- R6Class(
res <- feval(private$inner_predict(data_idx), data) res <- feval(private$inner_predict(data_idx), data)
# Check for name correctness # Check for name correctness
if(is.null(res$name) | is.null(res$value) | is.null(res$higher_better)) { if(is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
stop("lgb.Booster.eval: custom eval function should return a stop("lgb.Booster.eval: custom eval function should return a
list with attribute (name, value, higher_better)"); list with attribute (name, value, higher_better)");
} }
......
...@@ -125,7 +125,7 @@ Dataset <- R6Class( ...@@ -125,7 +125,7 @@ Dataset <- R6Class(
if (!is.null(private$categorical_feature)) { if (!is.null(private$categorical_feature)) {
# Check for character name # Check for character name
if (typeof(private$categorical_feature) == "character") { if (is.character(private$categorical_feature)) {
cate_indices <- as.list(match(private$categorical_feature, private$colnames) - 1) cate_indices <- as.list(match(private$categorical_feature, private$colnames) - 1)
......
...@@ -138,8 +138,9 @@ lgb.cv <- function(params = list(), ...@@ -138,8 +138,9 @@ lgb.cv <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1 begin_iteration <- predictor$current_iter() + 1
} }
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one # Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
if (sum(names(params) %in% c("num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds")) > 0) { n_trees <- c("num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds")
end_iteration <- begin_iteration + params[[which(names(params) %in% c("num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds"))[1]]] - 1 if (any(names(params) %in% n_trees)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
} else { } else {
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
} }
...@@ -180,7 +181,7 @@ lgb.cv <- function(params = list(), ...@@ -180,7 +181,7 @@ lgb.cv <- function(params = list(),
if (!is.null(folds)) { if (!is.null(folds)) {
# Check for list of folds or for single value # Check for list of folds or for single value
if (!is.list(folds) | length(folds) < 2) { if (!is.list(folds) || length(folds) < 2) {
stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold") stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold")
} }
...@@ -205,7 +206,7 @@ lgb.cv <- function(params = list(), ...@@ -205,7 +206,7 @@ lgb.cv <- function(params = list(),
} }
# Add printing log callback # Add printing log callback
if (verbose > 0 & eval_freq > 0) { if (verbose > 0 && eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
...@@ -215,9 +216,10 @@ lgb.cv <- function(params = list(), ...@@ -215,9 +216,10 @@ lgb.cv <- function(params = list(),
} }
# Check for early stopping passed as parameter when adding early stopping callback # Check for early stopping passed as parameter when adding early stopping callback
if (sum(names(params) %in% c("early_stopping_round", "early_stopping_rounds", "early_stopping")) > 0) { early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping")
if (params[[which(names(params) %in% c("early_stopping_round", "early_stopping_rounds", "early_stopping"))[1]]] > 0) { if (any(names(params) %in% eary_stop)) {
callbacks <- add.cb(callbacks, cb.early.stop(params[[which(names(params) %in% c("early_stopping_round", "early_stopping_rounds", "early_stopping"))[1]]], verbose = verbose)) if (params[[which(names(params) %in% early_stop)[1]]] > 0) {
callbacks <- add.cb(callbacks, cb.early.stop(params[[which(names(params) %in% early_stop)[1]]], verbose = verbose))
} }
} else { } else {
if (!is.null(early_stopping_rounds)) { if (!is.null(early_stopping_rounds)) {
...@@ -246,7 +248,7 @@ lgb.cv <- function(params = list(), ...@@ -246,7 +248,7 @@ lgb.cv <- function(params = list(),
} else { } else {
bst_folds <- lapply(seq_along(folds), function(k) { bst_folds <- lapply(seq_along(folds), function(k) {
dtest <- slice(data, folds[[k]]$fold) dtest <- slice(data, folds[[k]]$fold)
dtrain <- slice(data, (1:nrow(data))[-folds[[k]]$fold]) dtrain <- slice(data, (seq_len(nrow(data)))[-folds[[k]]$fold])
setinfo(dtrain, "weight", getinfo(data, "weight")[-folds[[k]]$fold]) setinfo(dtrain, "weight", getinfo(data, "weight")[-folds[[k]]$fold])
setinfo(dtrain, "init_score", getinfo(data, "init_score")[-folds[[k]]$fold]) setinfo(dtrain, "init_score", getinfo(data, "init_score")[-folds[[k]]$fold])
setinfo(dtrain, "group", getinfo(data, "group")[-folds[[k]]$group]) setinfo(dtrain, "group", getinfo(data, "group")[-folds[[k]]$group])
...@@ -270,7 +272,7 @@ lgb.cv <- function(params = list(), ...@@ -270,7 +272,7 @@ lgb.cv <- function(params = list(),
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
# Start training model using number of iterations to start and end with # Start training model using number of iterations to start and end with
for (i in seq(from = begin_iteration, to = end_iteration)) { for (i in seq.int(from = begin_iteration, to = end_iteration)) {
# Overwrite iteration in environment # Overwrite iteration in environment
env$iteration <- i env$iteration <- i
...@@ -320,7 +322,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { ...@@ -320,7 +322,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
if (is.null(group)) { if (is.null(group)) {
# Shuffle # Shuffle
rnd_idx <- sample(seq_len(nrows)) rnd_idx <- sample.int(nrows)
# Request stratified folds # Request stratified folds
if (isTRUE(stratified) && params$objective %in% c("binary", "multiclass") && length(label) == length(rnd_idx)) { if (isTRUE(stratified) && params$objective %in% c("binary", "multiclass") && length(label) == length(rnd_idx)) {
...@@ -335,10 +337,10 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { ...@@ -335,10 +337,10 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
folds <- list() folds <- list()
# Loop through each fold # Loop through each fold
for (i in 1:nfold) { for (i in seq_len(nfold)) {
kstep <- length(rnd_idx) %/% (nfold - i + 1) kstep <- length(rnd_idx) %/% (nfold - i + 1)
folds[[i]] <- rnd_idx[seq_len(kstep)] folds[[i]] <- rnd_idx[seq_len(kstep)]
rnd_idx <- rnd_idx[-(seq_len(kstep))] rnd_idx <- rnd_idx[-seq_len(kstep)]
} }
} }
...@@ -351,21 +353,20 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { ...@@ -351,21 +353,20 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
} }
# Degroup the groups # Degroup the groups
ungrouped <- inverse.rle(list(lengths = group, values = 1:length(group))) ungrouped <- inverse.rle(list(lengths = group, values = seq_along(group)))
# Can't stratify, shuffle # Can't stratify, shuffle
rnd_idx <- sample(seq_len(length(group))) rnd_idx <- sample.int(length(group))
# Make simple non-stratified folds # Make simple non-stratified folds
folds <- list() folds <- list()
# Loop through each fold # Loop through each fold
for (i in 1:nfold) { for (i in seq_len(nfold)) {
kstep <- length(rnd_idx) %/% (nfold - i + 1) kstep <- length(rnd_idx) %/% (nfold - i + 1)
folds[[i]] <- list() folds[[i]] <- list(fold = which(ungrouped %in% rnd_idx[seq_len(kstep)]),
folds[[i]][["fold"]] <- which(ungrouped %in% rnd_idx[1:kstep]) group = rnd_idx[seq_len(kstep)])
folds[[i]][["group"]] <- rnd_idx[1:kstep] rnd_idx <- rnd_idx[-seq_len(kstep)]
rnd_idx <- rnd_idx[-(1:kstep)]
} }
} }
...@@ -390,11 +391,11 @@ lgb.stratified.folds <- function(y, k = 10) { ...@@ -390,11 +391,11 @@ lgb.stratified.folds <- function(y, k = 10) {
## is too small, we just do regular unstratified CV ## is too small, we just do regular unstratified CV
if (is.numeric(y)) { if (is.numeric(y)) {
cuts <- floor(length(y) / k) cuts <- length(y) %/% k
if (cuts < 2) { cuts <- 2 } if (cuts < 2) { cuts <- 2 }
if (cuts > 5) { cuts <- 5 } if (cuts > 5) { cuts <- 5 }
y <- cut(y, y <- cut(y,
unique(stats::quantile(y, probs = seq(0, 1, length = cuts))), unique(stats::quantile(y, probs = seq.int(0, 1, length.out = cuts))),
include.lowest = TRUE) include.lowest = TRUE)
} }
...@@ -420,7 +421,7 @@ lgb.stratified.folds <- function(y, k = 10) { ...@@ -420,7 +421,7 @@ lgb.stratified.folds <- function(y, k = 10) {
## Add enough random integers to get length(seqVector) == numInClass[i] ## Add enough random integers to get length(seqVector) == numInClass[i]
if (numInClass[i] %% k > 0) { if (numInClass[i] %% k > 0) {
seqVector <- c(seqVector, sample(seq_len(k), numInClass[i] %% k)) seqVector <- c(seqVector, sample.int(k, numInClass[i] %% k))
} }
## Shuffle the integers for fold assignment and assign to this classes's data ## Shuffle the integers for fold assignment and assign to this classes's data
...@@ -436,7 +437,8 @@ lgb.stratified.folds <- function(y, k = 10) { ...@@ -436,7 +437,8 @@ lgb.stratified.folds <- function(y, k = 10) {
# Return data # Return data
out <- split(seq(along = y), foldVector) out <- split(seq(along = y), foldVector)
`names<-`(out, NULL) names(out) <- NULL
out
} }
lgb.merge.cv.result <- function(msg, showsd = TRUE) { lgb.merge.cv.result <- function(msg, showsd = TRUE) {
......
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
lgb.importance <- function(model, percentage = TRUE) { lgb.importance <- function(model, percentage = TRUE) {
# Check if model is a lightgbm model # Check if model is a lightgbm model
if (!any(class(model) == "lgb.Booster")) { if (!inherits(model, "lgb.Booster")) {
stop("'model' has to be an object of class lgb.Booster") stop("'model' has to be an object of class lgb.Booster")
} }
...@@ -48,7 +48,7 @@ lgb.importance <- function(model, percentage = TRUE) { ...@@ -48,7 +48,7 @@ lgb.importance <- function(model, percentage = TRUE) {
# Extract elements # Extract elements
tree_imp <- tree_dt %>% tree_imp <- tree_dt %>%
magrittr::extract(., magrittr::extract(.,
i = is.na(split_index) == FALSE, i = ! is.na(split_index),
j = .(Gain = sum(split_gain), Cover = sum(internal_count), Frequency = .N), j = .(Gain = sum(split_gain), Cover = sum(internal_count), Frequency = .N),
by = "split_feature") %T>% by = "split_feature") %T>%
data.table::setnames(., old = "split_feature", new = "Feature") %>% data.table::setnames(., old = "split_feature", new = "Feature") %>%
......
...@@ -66,7 +66,7 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { ...@@ -66,7 +66,7 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
# Lookup sequence # Lookup sequence
tree_dt[, split_feature := Lookup(split_feature, tree_dt[, split_feature := Lookup(split_feature,
seq(0, parsed_json_model$max_feature_idx, by = 1), seq.int(from = 0, to = parsed_json_model$max_feature_idx),
parsed_json_model$feature_names)] parsed_json_model$feature_names)]
# Return tree # Return tree
......
...@@ -46,7 +46,7 @@ lgb.plot.importance <- function(tree_imp, ...@@ -46,7 +46,7 @@ lgb.plot.importance <- function(tree_imp,
top_n <- min(top_n, nrow(tree_imp)) top_n <- min(top_n, nrow(tree_imp))
# Parse importance # Parse importance
tree_imp <- tree_imp[order(abs(get(measure)), decreasing = TRUE),][1:top_n,] tree_imp <- tree_imp[order(abs(get(measure)), decreasing = TRUE),][seq_len(top_n),]
# Attempt to setup a correct cex # Attempt to setup a correct cex
if (is.null(cex)) { if (is.null(cex)) {
......
...@@ -66,7 +66,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt, ...@@ -66,7 +66,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
} else { } else {
# More than one class, shape data first # More than one class, shape data first
layout_mat <- matrix(seq(1, cols * ceiling(num_class / cols)), layout_mat <- matrix(seq.int(to = cols * ceiling(num_class / cols)),
ncol = cols, nrow = ceiling(num_class / cols)) ncol = cols, nrow = ceiling(num_class / cols))
# Shape output # Shape output
...@@ -93,7 +93,7 @@ multiple.tree.plot.interpretation <- function(tree_interpretation, ...@@ -93,7 +93,7 @@ multiple.tree.plot.interpretation <- function(tree_interpretation,
cex) { cex) {
# Parse tree # Parse tree
tree_interpretation <- tree_interpretation[order(abs(Contribution), decreasing = TRUE),][1:min(top_n, .N),] tree_interpretation <- tree_interpretation[order(abs(Contribution), decreasing = TRUE),][seq_len(min(top_n, .N)),]
# Attempt to setup a correct cex # Attempt to setup a correct cex
if (is.null(cex)) { if (is.null(cex)) {
......
...@@ -44,10 +44,10 @@ ...@@ -44,10 +44,10 @@
lgb.prepare2 <- function(data) { lgb.prepare2 <- function(data) {
# data.table not behaving like data.frame # data.table not behaving like data.frame
if ("data.table" %in% class(data)) { if (inherits(data, "data.table")) {
# Get data classes # Get data classes
list_classes <- sapply(data, class) list_classes <- vapply(data, class, character(1))
# Convert characters to factors only (we can change them to numeric after) # Convert characters to factors only (we can change them to numeric after)
is_char <- which(list_classes == "character") is_char <- which(list_classes == "character")
...@@ -64,10 +64,10 @@ lgb.prepare2 <- function(data) { ...@@ -64,10 +64,10 @@ lgb.prepare2 <- function(data) {
} else { } else {
# Default routine (data.frame) # Default routine (data.frame)
if ("data.frame" %in% class(data)) { if (inherits(data, "data.frame")) {
# Get data classes # Get data classes
list_classes <- sapply(data, class) list_classes <- vapply(data, class, character(1))
# Convert characters to factors to numeric (integer is more efficient actually) # Convert characters to factors to numeric (integer is more efficient actually)
is_char <- which(list_classes == "character") is_char <- which(list_classes == "character")
......
...@@ -72,7 +72,7 @@ ...@@ -72,7 +72,7 @@
lgb.prepare_rules <- function(data, rules = NULL) { lgb.prepare_rules <- function(data, rules = NULL) {
# data.table not behaving like data.frame # data.table not behaving like data.frame
if ("data.table" %in% class(data)) { if (inherits(data, "data.table")) {
# Must use existing rules # Must use existing rules
if (!is.null(rules)) { if (!is.null(rules)) {
...@@ -88,7 +88,7 @@ lgb.prepare_rules <- function(data, rules = NULL) { ...@@ -88,7 +88,7 @@ lgb.prepare_rules <- function(data, rules = NULL) {
} else { } else {
# Get data classes # Get data classes
list_classes <- sapply(data, class) list_classes <- vapply(data, class, character(1))
# Map characters/factors # Map characters/factors
is_fix <- which(list_classes %in% c("character", "factor")) is_fix <- which(list_classes %in% c("character", "factor"))
...@@ -104,10 +104,10 @@ lgb.prepare_rules <- function(data, rules = NULL) { ...@@ -104,10 +104,10 @@ lgb.prepare_rules <- function(data, rules = NULL) {
mini_data <- data[[i]] mini_data <- data[[i]]
# Get unique values # Get unique values
if (class(mini_data) == "factor") { if (is.factor(mini_data)) {
mini_unique <- levels(mini_data) # Factor mini_unique <- levels(mini_data) # Factor
mini_numeric <- numeric(length(mini_unique)) mini_numeric <- numeric(length(mini_unique))
mini_numeric[1:length(mini_unique)] <- 1:length(mini_unique) # Respect ordinal if needed mini_numeric[seq_along(mini_unique)] <- seq_along(mini_unique) # Respect ordinal if needed
} else { } else {
mini_unique <- as.factor(unique(mini_data)) # Character mini_unique <- as.factor(unique(mini_data)) # Character
mini_numeric <- as.numeric(mini_unique) # No respect of ordinality mini_numeric <- as.numeric(mini_unique) # No respect of ordinality
...@@ -143,10 +143,10 @@ lgb.prepare_rules <- function(data, rules = NULL) { ...@@ -143,10 +143,10 @@ lgb.prepare_rules <- function(data, rules = NULL) {
} else { } else {
# Default routine (data.frame) # Default routine (data.frame)
if ("data.frame" %in% class(data)) { if (inherits(data, "data.frame")) {
# Get data classes # Get data classes
list_classes <- sapply(data, class) list_classes <- vapply(data, class, character(1))
# Map characters/factors # Map characters/factors
is_fix <- which(list_classes %in% c("character", "factor")) is_fix <- which(list_classes %in% c("character", "factor"))
...@@ -162,10 +162,10 @@ lgb.prepare_rules <- function(data, rules = NULL) { ...@@ -162,10 +162,10 @@ lgb.prepare_rules <- function(data, rules = NULL) {
mini_data <- data[[i]] mini_data <- data[[i]]
# Get unique values # Get unique values
if (class(mini_data) == "factor") { if (is.factor(mini_data)) {
mini_unique <- levels(mini_data) # Factor mini_unique <- levels(mini_data) # Factor
mini_numeric <- numeric(length(mini_unique)) mini_numeric <- numeric(length(mini_unique))
mini_numeric[1:length(mini_unique)] <- 1:length(mini_unique) # Respect ordinal if needed mini_numeric[seq_along(mini_unique)] <- seq_along(mini_unique) # Respect ordinal if needed
} else { } else {
mini_unique <- as.factor(unique(mini_data)) # Character mini_unique <- as.factor(unique(mini_data)) # Character
mini_numeric <- as.numeric(mini_unique) # No respect of ordinality mini_numeric <- as.numeric(mini_unique) # No respect of ordinality
......
...@@ -72,7 +72,7 @@ ...@@ -72,7 +72,7 @@
lgb.prepare_rules2 <- function(data, rules = NULL) { lgb.prepare_rules2 <- function(data, rules = NULL) {
# data.table not behaving like data.frame # data.table not behaving like data.frame
if ("data.table" %in% class(data)) { if (inherits(data, "data.table")) {
# Must use existing rules # Must use existing rules
if (!is.null(rules)) { if (!is.null(rules)) {
...@@ -88,7 +88,7 @@ lgb.prepare_rules2 <- function(data, rules = NULL) { ...@@ -88,7 +88,7 @@ lgb.prepare_rules2 <- function(data, rules = NULL) {
} else { } else {
# Get data classes # Get data classes
list_classes <- sapply(data, class) list_classes <- vapply(data, class, character(1))
# Map characters/factors # Map characters/factors
is_fix <- which(list_classes %in% c("character", "factor")) is_fix <- which(list_classes %in% c("character", "factor"))
...@@ -104,9 +104,9 @@ lgb.prepare_rules2 <- function(data, rules = NULL) { ...@@ -104,9 +104,9 @@ lgb.prepare_rules2 <- function(data, rules = NULL) {
mini_data <- data[[i]] mini_data <- data[[i]]
# Get unique values # Get unique values
if (class(mini_data) == "factor") { if (is.factor(mini_data)) {
mini_unique <- levels(mini_data) # Factor mini_unique <- levels(mini_data) # Factor
mini_numeric <- 1:length(mini_unique) # Respect ordinal if needed mini_numeric <- seq_along(mini_unique) # Respect ordinal if needed
} else { } else {
mini_unique <- as.factor(unique(mini_data)) # Character mini_unique <- as.factor(unique(mini_data)) # Character
mini_numeric <- as.integer(mini_unique) # No respect of ordinality mini_numeric <- as.integer(mini_unique) # No respect of ordinality
...@@ -142,10 +142,10 @@ lgb.prepare_rules2 <- function(data, rules = NULL) { ...@@ -142,10 +142,10 @@ lgb.prepare_rules2 <- function(data, rules = NULL) {
} else { } else {
# Default routine (data.frame) # Default routine (data.frame)
if ("data.frame" %in% class(data)) { if (inherits(data, "data.frame")) {
# Get data classes # Get data classes
list_classes <- sapply(data, class) list_classes <- vapply(data, class, character(1))
# Map characters/factors # Map characters/factors
is_fix <- which(list_classes %in% c("character", "factor")) is_fix <- which(list_classes %in% c("character", "factor"))
...@@ -161,9 +161,9 @@ lgb.prepare_rules2 <- function(data, rules = NULL) { ...@@ -161,9 +161,9 @@ lgb.prepare_rules2 <- function(data, rules = NULL) {
mini_data <- data[[i]] mini_data <- data[[i]]
# Get unique values # Get unique values
if (class(mini_data) == "factor") { if (is.factor(mini_data)) {
mini_unique <- levels(mini_data) # Factor mini_unique <- levels(mini_data) # Factor
mini_numeric <- 1:length(mini_unique) # Respect ordinal if needed mini_numeric <- seq_along(mini_unique) # Respect ordinal if needed
} else { } else {
mini_unique <- as.factor(unique(mini_data)) # Character mini_unique <- as.factor(unique(mini_data)) # Character
mini_numeric <- as.integer(mini_unique) # No respect of ordinality mini_numeric <- as.integer(mini_unique) # No respect of ordinality
......
...@@ -113,8 +113,9 @@ lgb.train <- function(params = list(), ...@@ -113,8 +113,9 @@ lgb.train <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1 begin_iteration <- predictor$current_iter() + 1
} }
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one # Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
if (sum(names(params) %in% c("num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds")) > 0) { n_rounds <- c("num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds")
end_iteration <- begin_iteration + params[[which(names(params) %in% c("num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds"))[1]]] - 1 if (any(names(params) %in% n_rounds)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_rounds)[1]]] - 1
} else { } else {
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
} }
...@@ -131,7 +132,7 @@ lgb.train <- function(params = list(), ...@@ -131,7 +132,7 @@ lgb.train <- function(params = list(),
# One or more validation dataset # One or more validation dataset
# Check for list as input and type correctness by object # Check for list as input and type correctness by object
if (!is.list(valids) || !all(sapply(valids, lgb.is.Dataset))) { if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements") stop("lgb.train: valids must be a list of lgb.Dataset elements")
} }
...@@ -192,19 +193,20 @@ lgb.train <- function(params = list(), ...@@ -192,19 +193,20 @@ lgb.train <- function(params = list(),
} }
# Add printing log callback # Add printing log callback
if (verbose > 0 & eval_freq > 0) { if (verbose > 0 && eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
# Add evaluation log callback # Add evaluation log callback
if (record & length(valids) > 0) { if (record && length(valids) > 0) {
callbacks <- add.cb(callbacks, cb.record.evaluation()) callbacks <- add.cb(callbacks, cb.record.evaluation())
} }
# Check for early stopping passed as parameter when adding early stopping callback # Check for early stopping passed as parameter when adding early stopping callback
if (sum(names(params) %in% c("early_stopping_round", "early_stopping_rounds", "early_stopping")) > 0) { early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping")
if (params[[which(names(params) %in% c("early_stopping_round", "early_stopping_rounds", "early_stopping"))[1]]] > 0) { if (any(names(params) %in% early_stop)) {
callbacks <- add.cb(callbacks, cb.early.stop(params[[which(names(params) %in% c("early_stopping_round", "early_stopping_rounds", "early_stopping"))[1]]], verbose = verbose)) if (params[[which(names(params) %in% early_stop)[1]]] > 0) {
callbacks <- add.cb(callbacks, cb.early.stop(params[[which(names(params) %in% early_stop)[1]]], verbose = verbose))
} }
} else { } else {
if (!is.null(early_stopping_rounds)) { if (!is.null(early_stopping_rounds)) {
...@@ -231,7 +233,7 @@ lgb.train <- function(params = list(), ...@@ -231,7 +233,7 @@ lgb.train <- function(params = list(),
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
# Start training model using number of iterations to start and end with # Start training model using number of iterations to start and end with
for (i in seq(from = begin_iteration, to = end_iteration)) { for (i in seq.int(from = begin_iteration, to = end_iteration)) {
# Overwrite iteration in environment # Overwrite iteration in environment
env$iteration <- i env$iteration <- i
......
...@@ -42,8 +42,9 @@ lgb.unloader <- function(restore = TRUE, wipe = FALSE, envir = .GlobalEnv) { ...@@ -42,8 +42,9 @@ lgb.unloader <- function(restore = TRUE, wipe = FALSE, envir = .GlobalEnv) {
# Should we wipe variables? (lgb.Booster, lgb.Dataset) # Should we wipe variables? (lgb.Booster, lgb.Dataset)
if (wipe) { if (wipe) {
rm(list = ls(envir = envir)[which(sapply(ls(.GlobalEnv), function(x) {"lgb.Booster" %in% class(get(x, envir = envir))}))], envir = envir) boosters <- Filter(function(x) inherits(get(x, envir = envir), "lgb.Booster"), ls(envir = envir))
rm(list = ls(envir = envir)[which(sapply(ls(.GlobalEnv), function(x) {"lgb.Dataset" %in% class(get(x, envir = envir))}))], envir = envir) datasets <- Filter(function(x) inherits(get(x, envir = envir), "lgb.Dataset"), ls(envir = envir))
rm(list = c(boosters, datasets), envir = envir)
gc(verbose = FALSE) gc(verbose = FALSE)
} }
......
...@@ -43,7 +43,7 @@ saveRDS.lgb.Booster <- function(object, ...@@ -43,7 +43,7 @@ saveRDS.lgb.Booster <- function(object,
raw = TRUE) { raw = TRUE) {
# Check if object has a raw value (and if the user wants to store the raw) # Check if object has a raw value (and if the user wants to store the raw)
if (is.na(object$raw) & (raw)) { if (is.na(object$raw) && raw) {
# Save model # Save model
object$save() object$save()
......
...@@ -21,7 +21,7 @@ lgb.encode.char <- function(arr, len) { ...@@ -21,7 +21,7 @@ lgb.encode.char <- function(arr, len) {
lgb.call <- function(fun_name, ret, ...) { lgb.call <- function(fun_name, ret, ...) {
# Set call state to a zero value # Set call state to a zero value
call_state <- as.integer(0L) call_state <- 0L
# Check for a ret call # Check for a ret call
if (!is.null(ret)) { if (!is.null(ret)) {
...@@ -51,7 +51,7 @@ lgb.call <- function(fun_name, ret, ...) { ...@@ -51,7 +51,7 @@ lgb.call <- function(fun_name, ret, ...) {
} }
# Return error # Return error
stop(paste0("api error: ", lgb.encode.char(err_msg, act_len))) stop("api error: ", lgb.encode.char(err_msg, act_len))
} }
...@@ -145,18 +145,8 @@ lgb.c_str <- function(x) { ...@@ -145,18 +145,8 @@ lgb.c_str <- function(x) {
lgb.check.r6.class <- function(object, name) { lgb.check.r6.class <- function(object, name) {
# Check for non-existence of R6 class # Check for non-existence of R6 class or named class
if (!("R6" %in% class(object))) { all(c("R6", name) %in% class(object))
return(FALSE)
}
# Check for non-existance of a named class
if (!(name %in% class(object))) {
return(FALSE)
}
# Return default value
return(TRUE)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment