saveRDS.lgb.Booster.R 2.85 KB
Newer Older
1
2
3
4
#' @name saveRDS.lgb.Booster
#' @title saveRDS for \code{lgb.Booster} models
#' @description Attempts to save a model using RDS. Has an additional parameter (\code{raw})
#'              which decides whether to save the raw model or not.
#' @details When \code{raw = TRUE} and the booster does not yet hold a serialized copy of the
#'          model (\code{object$raw} is \code{NA}), \code{object$save()} is called first so that
#'          the raw model is written into the RDS file. Afterwards \code{object$raw} is reset to
#'          \code{NA} to free that temporary copy; this reset also happens if writing fails.
#'          If \code{object$raw} already holds a model, the object is saved as-is.
#' @param object \code{lgb.Booster} object to serialize.
#' @param file a connection or the name of the file where the R object is saved to.
#' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default),
#'              a binary one is used. See the comments in the help for save.
#' @param version the workspace format version to use. \code{NULL} specifies the current default
#'                version used by \code{\link[base]{saveRDS}}. Passed through unchanged.
#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression,
#'                 or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of
#'                 compression to be used. Ignored if file is a connection.
#' @param refhook a hook function for handling reference objects.
#' @param raw whether to save the model in a raw variable or not; it is recommended to leave it
#'            as \code{TRUE}.
#'
#' @return NULL invisibly.
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 10L
#'     , valids = valids
#'     , early_stopping_rounds = 5L
#' )
#' model_file <- tempfile(fileext = ".rds")
#' saveRDS.lgb.Booster(model, model_file)
#' }
#' @export
saveRDS.lgb.Booster <- function(object,
                                file,
                                ascii = FALSE,
                                version = NULL,
                                compress = TRUE,
                                refhook = NULL,
                                raw = TRUE) {

  # `object$raw` is NA until the model has been serialized, but afterwards it
  # is a (multi-element) raw vector. Feeding `is.na()` of that straight into
  # `&&` errors in R >= 4.3, so only treat a single NA as "no raw copy yet".
  store_raw <- raw && isTRUE(is.na(object$raw))

  if (store_raw) {
    # Materialize the raw model so it is included in the RDS file.
    object$save()

    # Free the temporary raw copy once done, even if saveRDS() below fails.
    # NOTE(review): this only affects the caller's booster if `object` has
    # reference semantics (e.g. an R6 object) -- confirm.
    on.exit(object$raw <- NA, add = TRUE)
  }

  saveRDS(
    object
    , file = file
    , ascii = ascii
    , version = version
    , compress = compress
    , refhook = refhook
  )

  invisible(NULL)
}