readRDS.lgb.Booster.R 1.7 KB
Newer Older
1
2
3
#' @name readRDS.lgb.Booster
#' @title readRDS for \code{lgb.Booster} models
#' @description Attempts to load a model stored in a \code{.rds} file, using
#'              \code{\link[base]{readRDS}}. If the stored object carries its
#'              model as a string in its \code{raw} field, a fresh booster is
#'              rebuilt from that string with \code{lgb.load}.
#'
#' @param file a connection or the name of the file where the R object is read from.
#' @param refhook a hook function for handling reference objects
#'                (passed through to \code{\link[base]{readRDS}}).
#'
#' @return The \code{lgb.Booster} rebuilt from the model string stored in the
#'         object's \code{raw} field, with \code{best_iter},
#'         \code{record_evals} and \code{params} copied over from the stored
#'         object. If the stored object has no usable \code{raw} field
#'         (missing, empty or \code{NA}), the object is returned exactly as read.
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' model_file <- tempfile(fileext = ".rds")
#' saveRDS.lgb.Booster(model, model_file)
#' new_model <- readRDS.lgb.Booster(model_file)
#' }
#' @export
readRDS.lgb.Booster <- function(file, refhook = NULL) {

  object <- readRDS(file = file, refhook = refhook)

  # Look up the stored model string with exact matching: `[[` does not
  # partially match names the way `$` does on lists (e.g. a `raw_*` element).
  # Objects that are neither lists nor environments cannot carry the field,
  # and `$` would error on atomic vectors.
  raw_model <- if (is.list(object) || is.environment(object)) {
    object[["raw"]]
  } else {
    NULL
  }

  # No model string to rebuild from (field missing, empty, or NA): return the
  # deserialized object unchanged. Without this guard, `is.na(NULL)` yields a
  # zero-length condition and `if ()` errors.
  if (length(raw_model) == 0L || anyNA(raw_model)) {
    return(object)
  }

  # Create a fresh booster from the stored model string
  booster <- lgb.load(model_str = raw_model)

  # Restore best iteration, recorded evaluations and parameters
  booster$best_iter <- object[["best_iter"]]
  booster$record_evals <- object[["record_evals"]]
  booster$params <- object[["params"]]

  booster
}