saveRDS.lgb.Booster.R 2.77 KB
Newer Older
1
2
3
4
#' @name saveRDS.lgb.Booster
#' @title saveRDS for \code{lgb.Booster} models
#' @description Saves an \code{lgb.Booster} model to a file with \code{\link[base]{saveRDS}}.
#'              Has an additional parameter (\code{raw}) which decides whether the serialized
#'              model is stored inside the RDS file.
#' @param object an \code{lgb.Booster} model to serialize.
#' @param file a connection or the name of the file where the R object is saved to.
#' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default),
#'              a binary one is used. See the comments in the help for save.
#' @param version the workspace format version to use. \code{NULL} specifies the current default
#'                version (3 since R 3.6.0, 2 before that). Versions prior to 2 are not supported.
#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression,
#'                 or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of
#'                 compression to be used. Ignored if file is a connection.
#' @param refhook a hook function for handling reference objects.
#' @param raw whether to save the model in a raw variable or not, recommended to leave it to
#'            \code{TRUE}. When \code{TRUE} and \code{object$raw} is \code{NA},
#'            \code{object$save()} is called before serializing, and \code{object$raw} is reset
#'            to \code{NA} afterwards (also when writing the file fails). Otherwise \code{object}
#'            is serialized as-is.
#'
#' @return NULL invisibly.
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 10L
#'     , valids = valids
#'     , min_data = 1L
#'     , learning_rate = 1.0
#'     , early_stopping_rounds = 5L
#' )
#' model_file <- tempfile(fileext = ".rds")
#' saveRDS.lgb.Booster(model, model_file)
#' }
#' @export
saveRDS.lgb.Booster <- function(object,
                                file,
                                ascii = FALSE,
                                version = NULL,
                                compress = TRUE,
                                refhook = NULL,
                                raw = TRUE) {

  # Fail early with a clear message instead of an obscure `$` / call error.
  if (!inherits(object, "lgb.Booster")) {
    stop(
      "saveRDS.lgb.Booster: object should be an "
      , sQuote("lgb.Booster")
      , call. = FALSE
    )
  }

  # Only build the raw copy when the user wants it stored and the booster
  # does not already carry one.
  if (raw && is.na(object$raw)) {

    # Presumably $save() fills object$raw with the serialized model
    # (lgb.Booster method, defined elsewhere) -- confirm against the class.
    object$save()

    # The raw copy exists only for this call. Because the booster has
    # reference semantics, clear it on exit -- including when saveRDS()
    # below errors -- so the caller's model does not keep the extra copy.
    on.exit(object$raw <- NA, add = TRUE)

  }

  saveRDS(
    object
    , file = file
    , ascii = ascii
    , version = version
    , compress = compress
    , refhook = refhook
  )

  # Match the documented contract in both paths.
  invisible(NULL)
}