Unverified Commit d4629727 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] introduce Dataset methods set_field() and get_field() (#4571)



* [R-package] introduce Dataset set_field() and get_field()

* fix incorrect fields

* update pkgdown

* fix example

* fix another example

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* update docs
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 74c7904b
...@@ -58,4 +58,4 @@ Imports: ...@@ -58,4 +58,4 @@ Imports:
utils utils
SystemRequirements: SystemRequirements:
C++11 C++11
RoxygenNote: 7.1.1 RoxygenNote: 7.1.2
...@@ -3,10 +3,13 @@ ...@@ -3,10 +3,13 @@
S3method("dimnames<-",lgb.Dataset) S3method("dimnames<-",lgb.Dataset)
S3method(dim,lgb.Dataset) S3method(dim,lgb.Dataset)
S3method(dimnames,lgb.Dataset) S3method(dimnames,lgb.Dataset)
S3method(get_field,lgb.Dataset)
S3method(getinfo,lgb.Dataset) S3method(getinfo,lgb.Dataset)
S3method(predict,lgb.Booster) S3method(predict,lgb.Booster)
S3method(set_field,lgb.Dataset)
S3method(setinfo,lgb.Dataset) S3method(setinfo,lgb.Dataset)
S3method(slice,lgb.Dataset) S3method(slice,lgb.Dataset)
export(get_field)
export(getinfo) export(getinfo)
export(lgb.Dataset) export(lgb.Dataset)
export(lgb.Dataset.construct) export(lgb.Dataset.construct)
...@@ -30,6 +33,7 @@ export(lgb.unloader) ...@@ -30,6 +33,7 @@ export(lgb.unloader)
export(lightgbm) export(lightgbm)
export(readRDS.lgb.Booster) export(readRDS.lgb.Booster)
export(saveRDS.lgb.Booster) export(saveRDS.lgb.Booster)
export(set_field)
export(setinfo) export(setinfo)
export(slice) export(slice)
import(methods) import(methods)
......
...@@ -335,14 +335,17 @@ Dataset <- R6::R6Class( ...@@ -335,14 +335,17 @@ Dataset <- R6::R6Class(
for (i in seq_along(private$info)) { for (i in seq_along(private$info)) {
p <- private$info[i] p <- private$info[i]
self$setinfo(name = names(p), info = p[[1L]]) self$set_field(
field_name = names(p)
, data = p[[1L]]
)
} }
} }
# Get label information existence # Get label information existence
if (is.null(self$getinfo(name = "label"))) { if (is.null(self$get_field(field_name = "label"))) {
stop("lgb.Dataset.construct: label should be set") stop("lgb.Dataset.construct: label should be set")
} }
...@@ -452,19 +455,33 @@ Dataset <- R6::R6Class( ...@@ -452,19 +455,33 @@ Dataset <- R6::R6Class(
}, },
# Get information
getinfo = function(name) { getinfo = function(name) {
warning(paste0(
"Dataset$getinfo() is deprecated and will be removed in a future release. "
, "Use Dataset$get_field() instead."
))
return(
self$get_field(
field_name = name
)
)
},
get_field = function(field_name) {
# Check if attribute key is in the known attribute list # Check if attribute key is in the known attribute list
if (!is.character(name) || length(name) != 1L || !name %in% .INFO_KEYS()) { if (!is.character(field_name) || length(field_name) != 1L || !field_name %in% .INFO_KEYS()) {
stop("getinfo: name must one of the following: ", paste0(sQuote(.INFO_KEYS()), collapse = ", ")) stop(
"Dataset$get_field(): field_name must one of the following: "
, paste0(sQuote(.INFO_KEYS()), collapse = ", ")
)
} }
# Check for info name and handle # Check for info name and handle
if (is.null(private$info[[name]])) { if (is.null(private$info[[field_name]])) {
if (lgb.is.null.handle(x = private$handle)) { if (lgb.is.null.handle(x = private$handle)) {
stop("Cannot perform getinfo before constructing Dataset.") stop("Cannot perform Dataset$get_field() before constructing Dataset.")
} }
# Get field size of info # Get field size of info
...@@ -472,7 +489,7 @@ Dataset <- R6::R6Class( ...@@ -472,7 +489,7 @@ Dataset <- R6::R6Class(
.Call( .Call(
LGBM_DatasetGetFieldSize_R LGBM_DatasetGetFieldSize_R
, private$handle , private$handle
, name , field_name
, info_len , info_len
) )
...@@ -481,7 +498,7 @@ Dataset <- R6::R6Class( ...@@ -481,7 +498,7 @@ Dataset <- R6::R6Class(
# Get back fields # Get back fields
ret <- NULL ret <- NULL
ret <- if (name == "group") { ret <- if (field_name == "group") {
integer(info_len) # Integer integer(info_len) # Integer
} else { } else {
numeric(info_len) # Numeric numeric(info_len) # Numeric
...@@ -490,47 +507,62 @@ Dataset <- R6::R6Class( ...@@ -490,47 +507,62 @@ Dataset <- R6::R6Class(
.Call( .Call(
LGBM_DatasetGetField_R LGBM_DatasetGetField_R
, private$handle , private$handle
, name , field_name
, ret , ret
) )
private$info[[name]] <- ret private$info[[field_name]] <- ret
} }
} }
return(private$info[[name]]) return(private$info[[field_name]])
}, },
# Set information
setinfo = function(name, info) { setinfo = function(name, info) {
warning(paste0(
"Dataset$setinfo() is deprecated and will be removed in a future release. "
, "Use Dataset$set_field() instead."
))
return(
self$set_field(
field_name = name
, data = info
)
)
},
set_field = function(field_name, data) {
# Check if attribute key is in the known attribute list # Check if attribute key is in the known attribute list
if (!is.character(name) || length(name) != 1L || !name %in% .INFO_KEYS()) { if (!is.character(field_name) || length(field_name) != 1L || !field_name %in% .INFO_KEYS()) {
stop("setinfo: name must one of the following: ", paste0(sQuote(.INFO_KEYS()), collapse = ", ")) stop(
"Dataset$set_field(): field_name must one of the following: "
, paste0(sQuote(.INFO_KEYS()), collapse = ", ")
)
} }
# Check for type of information # Check for type of information
info <- if (name == "group") { data <- if (field_name == "group") {
as.integer(info) # Integer as.integer(data) # Integer
} else { } else {
as.numeric(info) # Numeric as.numeric(data) # Numeric
} }
# Store information privately # Store information privately
private$info[[name]] <- info private$info[[field_name]] <- data
if (!lgb.is.null.handle(x = private$handle) && !is.null(info)) { if (!lgb.is.null.handle(x = private$handle) && !is.null(data)) {
if (length(info) > 0L) { if (length(data) > 0L) {
.Call( .Call(
LGBM_DatasetSetField_R LGBM_DatasetSetField_R
, private$handle , private$handle
, name , field_name
, info , data
, length(info) , length(data)
) )
private$version <- private$version + 1L private$version <- private$version + 1L
...@@ -554,7 +586,7 @@ Dataset <- R6::R6Class( ...@@ -554,7 +586,7 @@ Dataset <- R6::R6Class(
, paste(names(additional_keyword_args), collapse = ", ") , paste(names(additional_keyword_args), collapse = ", ")
, ". These are ignored and should be removed. " , ". These are ignored and should be removed. "
, "To change the parameters of a Dataset produced by Dataset$slice(), use Dataset$set_params(). " , "To change the parameters of a Dataset produced by Dataset$slice(), use Dataset$set_params(). "
, "To modify attributes like 'init_score', use Dataset$setinfo(). " , "To modify attributes like 'init_score', use Dataset$set_field(). "
, "In future releases of lightgbm, this warning will become an error." , "In future releases of lightgbm, this warning will become an error."
)) ))
} }
...@@ -1110,7 +1142,7 @@ dimnames.lgb.Dataset <- function(x) { ...@@ -1110,7 +1142,7 @@ dimnames.lgb.Dataset <- function(x) {
#' #'
#' dsub <- lightgbm::slice(dtrain, seq_len(42L)) #' dsub <- lightgbm::slice(dtrain, seq_len(42L))
#' lgb.Dataset.construct(dsub) #' lgb.Dataset.construct(dsub)
#' labels <- lightgbm::getinfo(dsub, "label") #' labels <- lightgbm::get_field(dsub, "label")
#' } #' }
#' @export #' @export
slice <- function(dataset, ...) { slice <- function(dataset, ...) {
...@@ -1173,6 +1205,8 @@ getinfo <- function(dataset, ...) { ...@@ -1173,6 +1205,8 @@ getinfo <- function(dataset, ...) {
#' @export #' @export
getinfo.lgb.Dataset <- function(dataset, name, ...) { getinfo.lgb.Dataset <- function(dataset, name, ...) {
warning("Calling getinfo() on a lgb.Dataset is deprecated. Use get_field() instead.")
additional_args <- list(...) additional_args <- list(...)
if (length(additional_args) > 0L) { if (length(additional_args) > 0L) {
warning(paste0( warning(paste0(
...@@ -1187,7 +1221,7 @@ getinfo.lgb.Dataset <- function(dataset, name, ...) { ...@@ -1187,7 +1221,7 @@ getinfo.lgb.Dataset <- function(dataset, name, ...) {
stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object") stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
} }
return(dataset$getinfo(name = name)) return(dataset$get_field(field_name = name))
} }
...@@ -1236,6 +1270,8 @@ setinfo <- function(dataset, ...) { ...@@ -1236,6 +1270,8 @@ setinfo <- function(dataset, ...) {
#' @export #' @export
setinfo.lgb.Dataset <- function(dataset, name, info, ...) { setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
warning("Calling setinfo() on a lgb.Dataset is deprecated. Use set_field() instead.")
additional_args <- list(...) additional_args <- list(...)
if (length(additional_args) > 0L) { if (length(additional_args) > 0L) {
warning(paste0( warning(paste0(
...@@ -1250,7 +1286,102 @@ setinfo.lgb.Dataset <- function(dataset, name, info, ...) { ...@@ -1250,7 +1286,102 @@ setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object") stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
} }
return(invisible(dataset$setinfo(name = name, info = info))) return(invisible(dataset$set_field(field_name = name, data = info)))
}
#' @name get_field
#' @title Get one attribute of a \code{lgb.Dataset}
#' @description Get one attribute (field) of a \code{lgb.Dataset}, such as its labels or
#'              weights. This is the replacement for the deprecated \code{getinfo()}.
#' @details If the requested field was not stored on the R object beforehand, it is read
#'          from the underlying Dataset, which raises an error when the Dataset has not
#'          been constructed yet (see \code{\link{lgb.Dataset.construct}}).
#' @param dataset Object of class \code{lgb.Dataset}
#' @param field_name String with the name of the attribute to get. One of the following.
#' \itemize{
#'     \item \code{label}: label lightgbm learns from ;
#'     \item \code{weight}: to do a weight rescale ;
#'     \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to
#'         group rows together as ordered results from the same set of candidate results to be ranked.
#'         For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)},
#'         that means that you have 6 groups, where the first 10 records are in the first group,
#'         records 11-30 are in the second group, etc.}
#'     \item \code{init_score}: initial score is the base prediction lightgbm will boost from.
#' }
#' @return The requested attribute: an integer vector for \code{group}, a numeric vector for
#'         every other field. May be \code{NULL} or zero-length if the field was never set.
#' @seealso \code{\link{set_field}}
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#'
#' labels <- lightgbm::get_field(dtrain, "label")
#' lightgbm::set_field(dtrain, "label", 1 - labels)
#'
#' labels2 <- lightgbm::get_field(dtrain, "label")
#' # compare with a numeric tolerance instead of exact floating point equality
#' stopifnot(isTRUE(all.equal(labels2, 1 - labels)))
#' }
#' @export
get_field <- function(dataset, field_name) {
# S3 generic: the work is done by the method registered for the class of 'dataset'
UseMethod("get_field")
}
#' @rdname get_field
#' @export
get_field.lgb.Dataset <- function(dataset, field_name) {

    # Guard clause: this method may be called directly, so make sure the
    # object really is an lgb.Dataset before touching its R6 methods
    dataset_is_valid <- lgb.is.Dataset(x = dataset)
    if (!dataset_is_valid) {
        stop("get_field.lgb.Dataset(): input dataset should be an lgb.Dataset object")
    }

    # Delegate to the R6 method, which validates 'field_name' and fetches the field
    field_value <- dataset$get_field(field_name = field_name)
    return(field_value)
}
#' @name set_field
#' @title Set one attribute of a \code{lgb.Dataset} object
#' @description Set one attribute (field) of a \code{lgb.Dataset}, such as its labels or
#'              weights. This is the replacement for the deprecated \code{setinfo()}.
#' @param dataset Object of class \code{lgb.Dataset}
#' @param field_name String with the name of the attribute to set. One of the following.
#' \itemize{
#'     \item \code{label}: label lightgbm learns from ;
#'     \item \code{weight}: to do a weight rescale ;
#'     \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to
#'         group rows together as ordered results from the same set of candidate results to be ranked.
#'         For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)},
#'         that means that you have 6 groups, where the first 10 records are in the first group,
#'         records 11-30 are in the second group, etc.}
#'     \item \code{init_score}: initial score is the base prediction lightgbm will boost from.
#' }
#' @param data The data for the field. It is converted with \code{as.integer()} when
#'             \code{field_name} is \code{"group"} and with \code{as.numeric()} otherwise.
#'             See examples.
#' @return The \code{lgb.Dataset} you passed in.
#' @seealso \code{\link{get_field}}
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#'
#' labels <- lightgbm::get_field(dtrain, "label")
#' lightgbm::set_field(dtrain, "label", 1 - labels)
#'
#' labels2 <- lightgbm::get_field(dtrain, "label")
#' # all.equal() returns a description (not FALSE) on mismatch, so wrap it in isTRUE()
#' stopifnot(isTRUE(all.equal(labels2, 1 - labels)))
#' }
#' @export
set_field <- function(dataset, field_name, data) {
# S3 generic: the work is done by the method registered for the class of 'dataset'
UseMethod("set_field")
}
#' @rdname set_field
#' @export
set_field.lgb.Dataset <- function(dataset, field_name, data) {

    # Guard clause: this method may be called directly, so make sure the
    # object really is an lgb.Dataset before touching its R6 methods
    dataset_is_valid <- lgb.is.Dataset(x = dataset)
    if (!dataset_is_valid) {
        stop("set_field.lgb.Dataset: input dataset should be an lgb.Dataset object")
    }

    # Delegate to the R6 method and hand its result back invisibly so that
    # calling set_field() at the console does not print anything
    result <- dataset$set_field(
        field_name = field_name
        , data = data
    )
    return(invisible(result))
}
#' @name lgb.Dataset.set.categorical #' @name lgb.Dataset.set.categorical
......
...@@ -206,7 +206,7 @@ lgb.cv <- function(params = list() ...@@ -206,7 +206,7 @@ lgb.cv <- function(params = list()
) )
if (!is.null(weight)) { if (!is.null(weight)) {
data$setinfo(name = "weight", info = weight) data$set_field(field_name = "weight", data = weight)
} }
# Update parameters with parsed parameters # Update parameters with parsed parameters
...@@ -245,8 +245,8 @@ lgb.cv <- function(params = list() ...@@ -245,8 +245,8 @@ lgb.cv <- function(params = list()
nfold = nfold nfold = nfold
, nrows = nrow(data) , nrows = nrow(data)
, stratified = stratified , stratified = stratified
, label = getinfo(dataset = data, name = "label") , label = get_field(dataset = data, field_name = "label")
, group = getinfo(dataset = data, name = "group") , group = get_field(dataset = data, field_name = "group")
, params = params , params = params
) )
...@@ -320,8 +320,8 @@ lgb.cv <- function(params = list() ...@@ -320,8 +320,8 @@ lgb.cv <- function(params = list()
if (folds_have_group) { if (folds_have_group) {
test_indices <- folds[[k]]$fold test_indices <- folds[[k]]$fold
test_group_indices <- folds[[k]]$group test_group_indices <- folds[[k]]$group
test_groups <- getinfo(dataset = data, name = "group")[test_group_indices] test_groups <- get_field(dataset = data, field_name = "group")[test_group_indices]
train_groups <- getinfo(dataset = data, name = "group")[-test_group_indices] train_groups <- get_field(dataset = data, field_name = "group")[-test_group_indices]
} else { } else {
test_indices <- folds[[k]] test_indices <- folds[[k]]
} }
...@@ -330,28 +330,28 @@ lgb.cv <- function(params = list() ...@@ -330,28 +330,28 @@ lgb.cv <- function(params = list()
# set up test set # set up test set
indexDT <- data.table::data.table( indexDT <- data.table::data.table(
indices = test_indices indices = test_indices
, weight = getinfo(dataset = data, name = "weight")[test_indices] , weight = get_field(dataset = data, field_name = "weight")[test_indices]
, init_score = getinfo(dataset = data, name = "init_score")[test_indices] , init_score = get_field(dataset = data, field_name = "init_score")[test_indices]
) )
data.table::setorderv(x = indexDT, cols = "indices", order = 1L) data.table::setorderv(x = indexDT, cols = "indices", order = 1L)
dtest <- slice(data, indexDT$indices) dtest <- slice(data, indexDT$indices)
setinfo(dataset = dtest, name = "weight", info = indexDT$weight) set_field(dataset = dtest, field_name = "weight", data = indexDT$weight)
setinfo(dataset = dtest, name = "init_score", info = indexDT$init_score) set_field(dataset = dtest, field_name = "init_score", data = indexDT$init_score)
# set up training set # set up training set
indexDT <- data.table::data.table( indexDT <- data.table::data.table(
indices = train_indices indices = train_indices
, weight = getinfo(dataset = data, name = "weight")[train_indices] , weight = get_field(dataset = data, field_name = "weight")[train_indices]
, init_score = getinfo(dataset = data, name = "init_score")[train_indices] , init_score = get_field(dataset = data, field_name = "init_score")[train_indices]
) )
data.table::setorderv(x = indexDT, cols = "indices", order = 1L) data.table::setorderv(x = indexDT, cols = "indices", order = 1L)
dtrain <- slice(data, indexDT$indices) dtrain <- slice(data, indexDT$indices)
setinfo(dataset = dtrain, name = "weight", info = indexDT$weight) set_field(dataset = dtrain, field_name = "weight", data = indexDT$weight)
setinfo(dataset = dtrain, name = "init_score", info = indexDT$init_score) set_field(dataset = dtrain, field_name = "init_score", data = indexDT$init_score)
if (folds_have_group) { if (folds_have_group) {
setinfo(dataset = dtest, name = "group", info = test_groups) set_field(dataset = dtest, field_name = "group", data = test_groups)
setinfo(dataset = dtrain, name = "group", info = train_groups) set_field(dataset = dtrain, field_name = "group", data = train_groups)
} }
booster <- Booster$new(params = params, train_set = dtrain) booster <- Booster$new(params = params, train_set = dtrain)
......
...@@ -21,7 +21,11 @@ ...@@ -21,7 +21,11 @@
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train #' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label) #' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label))) #' set_field(
#' dataset = dtrain
#' , field_name = "init_score"
#' , data = rep(Logit(mean(train$label)), length(train$label))
#' )
#' data(agaricus.test, package = "lightgbm") #' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test #' test <- agaricus.test
#' #'
......
...@@ -25,7 +25,11 @@ ...@@ -25,7 +25,11 @@
#' agaricus.train$data #' agaricus.train$data
#' , label = labels #' , label = labels
#' ) #' )
#' setinfo(dtrain, "init_score", rep(Logit(mean(labels)), length(labels))) #' set_field(
#' dataset = dtrain
#' , field_name = "init_score"
#' , data = rep(Logit(mean(labels)), length(labels))
#' )
#' #'
#' data(agaricus.test, package = "lightgbm") #' data(agaricus.test, package = "lightgbm")
#' #'
......
...@@ -147,8 +147,8 @@ bst <- lgb.train( ...@@ -147,8 +147,8 @@ bst <- lgb.train(
, valids = valids , valids = valids
) )
# information can be extracted from lgb.Dataset using getinfo # information can be extracted from lgb.Dataset using get_field()
label <- getinfo(dtest, "label") label <- get_field(dtest, "label")
pred <- predict(bst, test$data) pred <- predict(bst, test$data)
err <- as.numeric(sum(as.integer(pred > 0.5) != label)) / length(label) err <- as.numeric(sum(as.integer(pred > 0.5) != label)) / length(label)
print(paste("test-error=", err)) print(paste("test-error=", err))
...@@ -27,8 +27,8 @@ ptest <- predict(bst, agaricus.test$data, rawscore = TRUE) ...@@ -27,8 +27,8 @@ ptest <- predict(bst, agaricus.test$data, rawscore = TRUE)
# set the init_score property of dtrain and dtest # set the init_score property of dtrain and dtest
# base margin is the base prediction we will boost from # base margin is the base prediction we will boost from
setinfo(dtrain, "init_score", ptrain) set_field(dtrain, "init_score", ptrain)
setinfo(dtest, "init_score", ptest) set_field(dtest, "init_score", ptest)
print("This is result of boost from initial prediction") print("This is result of boost from initial prediction")
bst <- lgb.train( bst <- lgb.train(
......
...@@ -42,7 +42,7 @@ lgb.cv( ...@@ -42,7 +42,7 @@ lgb.cv(
print("Running cross validation, with cutomsized loss function") print("Running cross validation, with cutomsized loss function")
logregobj <- function(preds, dtrain) { logregobj <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds)) preds <- 1.0 / (1.0 + exp(-preds))
grad <- preds - labels grad <- preds - labels
hess <- preds * (1.0 - preds) hess <- preds * (1.0 - preds)
...@@ -55,7 +55,7 @@ logregobj <- function(preds, dtrain) { ...@@ -55,7 +55,7 @@ logregobj <- function(preds, dtrain) {
# For example, we are doing logistic loss, the prediction is score before logistic transformation # For example, we are doing logistic loss, the prediction is score before logistic transformation
# Keep this in mind when you use the customization, and maybe you need write customized evaluation function # Keep this in mind when you use the customization, and maybe you need write customized evaluation function
evalerror <- function(preds, dtrain) { evalerror <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds)) preds <- 1.0 / (1.0 + exp(-preds))
err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels) err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
return(list(name = "error", value = err, higher_better = FALSE)) return(list(name = "error", value = err, higher_better = FALSE))
......
...@@ -21,7 +21,7 @@ num_round <- 20L ...@@ -21,7 +21,7 @@ num_round <- 20L
# User define objective function, given prediction, return gradient and second order gradient # User define objective function, given prediction, return gradient and second order gradient
# This is loglikelihood loss # This is loglikelihood loss
logregobj <- function(preds, dtrain) { logregobj <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds)) preds <- 1.0 / (1.0 + exp(-preds))
grad <- preds - labels grad <- preds - labels
hess <- preds * (1.0 - preds) hess <- preds * (1.0 - preds)
...@@ -35,7 +35,7 @@ logregobj <- function(preds, dtrain) { ...@@ -35,7 +35,7 @@ logregobj <- function(preds, dtrain) {
# The built-in evaluation error assumes input is after logistic transformation # The built-in evaluation error assumes input is after logistic transformation
# Keep this in mind when you use the customization, and maybe you need write customized evaluation function # Keep this in mind when you use the customization, and maybe you need write customized evaluation function
evalerror <- function(preds, dtrain) { evalerror <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- get_field(dtrain, "label")
err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels) err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
return(list(name = "error", value = err, higher_better = FALSE)) return(list(name = "error", value = err, higher_better = FALSE))
} }
......
...@@ -43,7 +43,7 @@ probs_builtin <- exp(preds_builtin) / rowSums(exp(preds_builtin)) ...@@ -43,7 +43,7 @@ probs_builtin <- exp(preds_builtin) / rowSums(exp(preds_builtin))
# User defined objective function, given prediction, return gradient and second order gradient # User defined objective function, given prediction, return gradient and second order gradient
custom_multiclass_obj <- function(preds, dtrain) { custom_multiclass_obj <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- get_field(dtrain, "label")
# preds is a matrix with rows corresponding to samples and columns corresponding to choices # preds is a matrix with rows corresponding to samples and columns corresponding to choices
preds <- matrix(preds, nrow = length(labels)) preds <- matrix(preds, nrow = length(labels))
...@@ -73,7 +73,7 @@ custom_multiclass_obj <- function(preds, dtrain) { ...@@ -73,7 +73,7 @@ custom_multiclass_obj <- function(preds, dtrain) {
# define custom metric # define custom metric
custom_multiclass_metric <- function(preds, dtrain) { custom_multiclass_metric <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- get_field(dtrain, "label")
preds <- matrix(preds, nrow = length(labels)) preds <- matrix(preds, nrow = length(labels))
preds <- preds - apply(preds, 1L, max) preds <- preds - apply(preds, 1L, max)
prob <- exp(preds) / rowSums(exp(preds)) prob <- exp(preds) / rowSums(exp(preds))
......
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/lgb.Dataset.R
\name{get_field}
\alias{get_field}
\alias{get_field.lgb.Dataset}
\title{Get one attribute of a \code{lgb.Dataset}}
\usage{
get_field(dataset, field_name)
\method{get_field}{lgb.Dataset}(dataset, field_name)
}
\arguments{
\item{dataset}{Object of class \code{lgb.Dataset}}
\item{field_name}{String with the name of the attribute to get. One of the following.
\itemize{
\item \code{label}: label lightgbm learns from ;
\item \code{weight}: to do a weight rescale ;
\item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to
group rows together as ordered results from the same set of candidate results to be ranked.
For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)},
that means that you have 6 groups, where the first 10 records are in the first group,
records 11-30 are in the second group, etc.}
\item \code{init_score}: initial score is the base prediction lightgbm will boost from.
}}
}
\value{
requested attribute
}
\description{
Get one attribute of a \code{lgb.Dataset}
}
\examples{
\donttest{
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
lgb.Dataset.construct(dtrain)
labels <- lightgbm::get_field(dtrain, "label")
lightgbm::set_field(dtrain, "label", 1 - labels)
labels2 <- lightgbm::get_field(dtrain, "label")
stopifnot(all(labels2 == 1 - labels))
}
}
...@@ -34,7 +34,11 @@ Logit <- function(x) log(x / (1.0 - x)) ...@@ -34,7 +34,11 @@ Logit <- function(x) log(x / (1.0 - x))
data(agaricus.train, package = "lightgbm") data(agaricus.train, package = "lightgbm")
train <- agaricus.train train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label) dtrain <- lgb.Dataset(train$data, label = train$label)
setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label))) set_field(
dataset = dtrain
, field_name = "init_score"
, data = rep(Logit(mean(train$label)), length(train$label))
)
data(agaricus.test, package = "lightgbm") data(agaricus.test, package = "lightgbm")
test <- agaricus.test test <- agaricus.test
......
...@@ -44,7 +44,11 @@ dtrain <- lgb.Dataset( ...@@ -44,7 +44,11 @@ dtrain <- lgb.Dataset(
agaricus.train$data agaricus.train$data
, label = labels , label = labels
) )
setinfo(dtrain, "init_score", rep(Logit(mean(labels)), length(labels))) set_field(
dataset = dtrain
, field_name = "init_score"
, data = rep(Logit(mean(labels)), length(labels))
)
data(agaricus.test, package = "lightgbm") data(agaricus.test, package = "lightgbm")
......
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/lgb.Dataset.R
\name{set_field}
\alias{set_field}
\alias{set_field.lgb.Dataset}
\title{Set one attribute of a \code{lgb.Dataset} object}
\usage{
set_field(dataset, field_name, data)
\method{set_field}{lgb.Dataset}(dataset, field_name, data)
}
\arguments{
\item{dataset}{Object of class \code{lgb.Dataset}}
\item{field_name}{String with the name of the attribute to set. One of the following.
\itemize{
\item \code{label}: label lightgbm learns from ;
\item \code{weight}: to do a weight rescale ;
\item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to
group rows together as ordered results from the same set of candidate results to be ranked.
For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)},
that means that you have 6 groups, where the first 10 records are in the first group,
records 11-30 are in the second group, etc.}
\item \code{init_score}: initial score is the base prediction lightgbm will boost from.
}}
\item{data}{The data for the field. See examples.}
}
\value{
The \code{lgb.Dataset} you passed in.
}
\description{
Set one attribute of a \code{lgb.Dataset}
}
\examples{
\donttest{
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
lgb.Dataset.construct(dtrain)
labels <- lightgbm::get_field(dtrain, "label")
lightgbm::set_field(dtrain, "label", 1 - labels)
labels2 <- lightgbm::get_field(dtrain, "label")
stopifnot(all.equal(labels2, 1 - labels))
}
}
...@@ -31,6 +31,6 @@ dtrain <- lgb.Dataset(train$data, label = train$label) ...@@ -31,6 +31,6 @@ dtrain <- lgb.Dataset(train$data, label = train$label)
dsub <- lightgbm::slice(dtrain, seq_len(42L)) dsub <- lightgbm::slice(dtrain, seq_len(42L))
lgb.Dataset.construct(dsub) lgb.Dataset.construct(dsub)
labels <- lightgbm::getinfo(dsub, "label") labels <- lightgbm::get_field(dsub, "label")
} }
} }
...@@ -56,8 +56,8 @@ reference: ...@@ -56,8 +56,8 @@ reference:
contents: contents:
- '`dim.lgb.Dataset`' - '`dim.lgb.Dataset`'
- '`dimnames.lgb.Dataset`' - '`dimnames.lgb.Dataset`'
- '`getinfo`' - '`get_field`'
- '`setinfo`' - '`set_field`'
- '`slice`' - '`slice`'
- '`lgb.Dataset`' - '`lgb.Dataset`'
- '`lgb.Dataset.construct`' - '`lgb.Dataset.construct`'
......
...@@ -9,7 +9,7 @@ watchlist <- list(eval = dtest, train = dtrain) ...@@ -9,7 +9,7 @@ watchlist <- list(eval = dtest, train = dtrain)
TOLERANCE <- 1e-6 TOLERANCE <- 1e-6
logregobj <- function(preds, dtrain) { logregobj <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds)) preds <- 1.0 / (1.0 + exp(-preds))
grad <- preds - labels grad <- preds - labels
hess <- preds * (1.0 - preds) hess <- preds * (1.0 - preds)
...@@ -21,7 +21,7 @@ logregobj <- function(preds, dtrain) { ...@@ -21,7 +21,7 @@ logregobj <- function(preds, dtrain) {
# This may make built-in evaluation metric calculate wrong results # This may make built-in evaluation metric calculate wrong results
# Keep this in mind when you use the customization, and maybe you need write customized evaluation function # Keep this in mind when you use the customization, and maybe you need write customized evaluation function
evalerror <- function(preds, dtrain) { evalerror <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds)) preds <- 1.0 / (1.0 + exp(-preds))
err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels) err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
return(list( return(list(
......
...@@ -14,6 +14,7 @@ test_that("lgb.Dataset: basic construction, saving, loading", { ...@@ -14,6 +14,7 @@ test_that("lgb.Dataset: basic construction, saving, loading", {
# from dense matrix # from dense matrix
dtest2 <- lgb.Dataset(as.matrix(test_data), label = test_label) dtest2 <- lgb.Dataset(as.matrix(test_data), label = test_label)
expect_equal(getinfo(dtest1, "label"), getinfo(dtest2, "label")) expect_equal(getinfo(dtest1, "label"), getinfo(dtest2, "label"))
expect_equal(get_field(dtest1, "label"), get_field(dtest2, "label"))
# save to a local file # save to a local file
tmp_file <- tempfile("lgb.Dataset_") tmp_file <- tempfile("lgb.Dataset_")
...@@ -23,6 +24,7 @@ test_that("lgb.Dataset: basic construction, saving, loading", { ...@@ -23,6 +24,7 @@ test_that("lgb.Dataset: basic construction, saving, loading", {
lgb.Dataset.construct(dtest3) lgb.Dataset.construct(dtest3)
unlink(tmp_file) unlink(tmp_file)
expect_equal(getinfo(dtest1, "label"), getinfo(dtest3, "label")) expect_equal(getinfo(dtest1, "label"), getinfo(dtest3, "label"))
expect_equal(get_field(dtest1, "label"), get_field(dtest3, "label"))
}) })
test_that("lgb.Dataset: getinfo & setinfo", { test_that("lgb.Dataset: getinfo & setinfo", {
...@@ -40,6 +42,21 @@ test_that("lgb.Dataset: getinfo & setinfo", { ...@@ -40,6 +42,21 @@ test_that("lgb.Dataset: getinfo & setinfo", {
expect_error(setinfo(dtest, "asdf", test_label)) expect_error(setinfo(dtest, "asdf", test_label))
}) })
test_that("lgb.Dataset: get_field & set_field", {
    # test_data and test_label are fixtures defined earlier in this test file
    dtest <- lgb.Dataset(test_data)
    dtest$construct()

    # round trip: a label written with set_field() must be read back unchanged
    set_field(dtest, "label", test_label)
    labels <- get_field(dtest, "label")
    expect_equal(test_label, labels)

    # fields that were never set come back empty
    expect_true(length(get_field(dtest, "weight")) == 0L)
    expect_true(length(get_field(dtest, "init_score")) == 0L)

    # any other label should error
    expect_error(set_field(dtest, "asdf", test_label))
})
test_that("lgb.Dataset: slice, dim", { test_that("lgb.Dataset: slice, dim", {
dtest <- lgb.Dataset(test_data, label = test_label) dtest <- lgb.Dataset(test_data, label = test_label)
lgb.Dataset.construct(dtest) lgb.Dataset.construct(dtest)
...@@ -255,6 +272,19 @@ test_that("lgb.Dataset$setinfo() should convert 'group' to integer", { ...@@ -255,6 +272,19 @@ test_that("lgb.Dataset$setinfo() should convert 'group' to integer", {
expect_identical(ds$getinfo("group"), as.integer(group_as_numeric)) expect_identical(ds$getinfo("group"), as.integer(group_as_numeric))
}) })
test_that("lgb.Dataset$set_field() should convert 'group' to integer", {
    dataset <- lgb.Dataset(
        data = matrix(rnorm(100L), nrow = 50L, ncol = 2L)
        , label = sample(c(0L, 1L), size = 50L, replace = TRUE)
    )
    dataset$construct()

    # no group information exists until one is assigned
    expect_null(dataset$get_field("group"))

    # group sizes supplied as doubles must be stored and returned as integers
    numeric_groups <- rep(25.0, 2L)
    dataset$set_field("group", numeric_groups)
    stored_groups <- dataset$get_field("group")
    expect_identical(stored_groups, as.integer(numeric_groups))
})
test_that("lgb.Dataset should throw an error if 'reference' is provided but of the wrong format", { test_that("lgb.Dataset should throw an error if 'reference' is provided but of the wrong format", {
data(agaricus.test, package = "lightgbm") data(agaricus.test, package = "lightgbm")
test_data <- agaricus.test$data[1L:100L, ] test_data <- agaricus.test$data[1L:100L, ]
......
...@@ -11,10 +11,10 @@ test_that("lgb.intereprete works as expected for binary classification", { ...@@ -11,10 +11,10 @@ test_that("lgb.intereprete works as expected for binary classification", {
data(agaricus.train, package = "lightgbm") data(agaricus.train, package = "lightgbm")
train <- agaricus.train train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label) dtrain <- lgb.Dataset(train$data, label = train$label)
setinfo( set_field(
dataset = dtrain dataset = dtrain
, "init_score" , field_name = "init_score"
, rep( , data = rep(
.logit(mean(train$label)) .logit(mean(train$label))
, length(train$label) , length(train$label)
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment