Commit 90fc72a1, authored by Laurae and committed by Guolin Ke
Browse files

[R-package] CV Improvements, fix #390 (#791)

* Force fork

* Better random CV + Group + Info handling

* Trailing comma missing

* Switch group fold to nrow

* Switch comparison from non-atomic to atomic
parent 666191b2
......@@ -194,6 +194,7 @@ lgb.cv <- function(params = list(),
nrow(data),
stratified,
getinfo(data, "label"),
getinfo(data, "group"),
params)
}
......@@ -218,14 +219,35 @@ lgb.cv <- function(params = list(),
# Categorize callbacks
cb <- categorize.callbacks(callbacks)
# Construct booster using a list apply
# Construct booster using a list apply, check if requires group or not
if (!is.list(folds[[1]])) {
bst_folds <- lapply(seq_along(folds), function(k) {
dtest <- slice(data, folds[[k]])
dtrain <- slice(data, unlist(folds[-k]))
setinfo(dtrain, "weight", getinfo(data, "weight")[-folds[[k]]])
setinfo(dtrain, "init_score", getinfo(data, "init_score")[-folds[[k]]])
setinfo(dtest, "weight", getinfo(data, "weight")[folds[[k]]])
setinfo(dtest, "init_score", getinfo(data, "init_score")[folds[[k]]])
booster <- Booster$new(params, dtrain)
booster$add_valid(dtest, "valid")
list(booster = booster)
})
} else {
bst_folds <- lapply(seq_along(folds), function(k) {
dtest <- slice(data, folds[[k]]$fold)
dtrain <- slice(data, (1:nrow(data))[-folds[[k]]$fold])
setinfo(dtrain, "weight", getinfo(data, "weight")[-folds[[k]]$fold])
setinfo(dtrain, "init_score", getinfo(data, "init_score")[-folds[[k]]$fold])
setinfo(dtrain, "group", getinfo(data, "group")[-folds[[k]]$group])
setinfo(dtest, "weight", getinfo(data, "weight")[folds[[k]]$fold])
setinfo(dtest, "init_score", getinfo(data, "init_score")[folds[[k]]$fold])
setinfo(dtest, "group", getinfo(data, "group")[folds[[k]]$group])
booster <- Booster$new(params, dtrain)
booster$add_valid(dtest, "valid")
list(booster = booster)
})
}
# Create new booster
cv_booster <- CVBooster$new(bst_folds)
......@@ -281,12 +303,10 @@ lgb.cv <- function(params = list(),
}
# Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
# Cannot do it for rank
if (exists('objective', where = params) && is.character(params$objective) && params$objective == "lambdarank") {
stop("\n\tAutomatic generation of CV-folds is not implemented for lambdarank!\n", "\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
}
# Check for group existence
if (is.null(group)) {
# Shuffle
rnd_idx <- sample(seq_len(nrows))
......@@ -301,15 +321,41 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
} else {
# Make simple non-stratified folds
kstep <- length(rnd_idx) %/% nfold
folds <- list()
# Loop through each fold
for (i in seq_len(nfold - 1)) {
for (i in 1:nfold) {
kstep <- length(rnd_idx) %/% (nfold - i + 1)
folds[[i]] <- rnd_idx[seq_len(kstep)]
rnd_idx <- rnd_idx[-(seq_len(kstep))]
}
folds[[nfold]] <- rnd_idx
}
} else {
# When doing group, stratified is not possible (only random selection)
if (nfold > length(group)) {
stop("\n\tYou requested too many folds for the number of available groups.\n")
}
# Degroup the groups
ungrouped <- inverse.rle(list(lengths = group, values = 1:length(group)))
# Can't stratify, shuffle
rnd_idx <- sample(seq_len(length(group)))
# Make simple non-stratified folds
folds <- list()
# Loop through each fold
for (i in 1:nfold) {
kstep <- length(rnd_idx) %/% (nfold - i + 1)
folds[[i]] <- list()
folds[[i]][["fold"]] <- which(ungrouped %in% rnd_idx[1:kstep])
folds[[i]][["group"]] <- rnd_idx[1:kstep]
rnd_idx <- rnd_idx[-(1:kstep)]
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment