# saveRDS.lgb.Booster.R
#' @name saveRDS.lgb.Booster
#' @title saveRDS for \code{lgb.Booster} models
#' @description Attempts to save a model using RDS. Has an additional parameter (\code{raw})
#'              which decides whether to save the raw model or not.
#' @param object R object to serialize.
#' @param file a connection or the name of the file where the R object is saved to.
#' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default),
#'              a binary one is used. See the comments in the help for \code{\link[base]{save}}.
#' @param version the workspace format version to use. \code{NULL} specifies the current default
#'                version. See \code{\link[base]{saveRDS}} for the supported values.
#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression,
#'                 or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of
#'                 compression to be used. Ignored if file is a connection.
#' @param refhook a hook function for handling reference objects.
#' @param raw whether to save the model in a raw variable or not; recommended to leave it as \code{TRUE}.
#'            The raw copy is only created (via \code{object$save()}) when \code{object$raw} is a
#'            single \code{NA}; it is reset to \code{NA} afterwards. An object that already carries
#'            a raw value is saved as-is.
#'
#' @return NULL invisibly.
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 10L
#'     , valids = valids
#'     , min_data = 1L
#'     , learning_rate = 1.0
#'     , early_stopping_rounds = 5L
#' )
#' model_file <- tempfile(fileext = ".rds")
#' saveRDS.lgb.Booster(model, model_file)
#' }
#' @export
saveRDS.lgb.Booster <- function(object,
                                file = "",
                                ascii = FALSE,
                                version = NULL,
                                compress = TRUE,
                                refhook = NULL,
                                raw = TRUE) {

  # Only build a raw copy when the caller asks for it AND the object does not
  # already carry one. The length check keeps `&&` scalar: a populated raw
  # value may have more than one element, and `&&` errors on such operands
  # since R 4.3.0.
  needs_raw <- isTRUE(raw) && length(object$raw) == 1L && is.na(object$raw)

  if (needs_raw) {
    # Register the reset before calling save() so the temporary raw copy is
    # dropped from the live object even if save() or saveRDS() fails.
    on.exit(object$raw <- NA, add = TRUE)

    # Populate object$raw so that saveRDS() below captures the model.
    object$save()
  }

  saveRDS(
    object
    , file = file
    , ascii = ascii
    , version = version
    , compress = compress
    , refhook = refhook
  )

  invisible(NULL)
}