saveRDS.lgb.Booster.R 2.7 KB
Newer Older
Nikita Titov's avatar
Nikita Titov committed
1
#' saveRDS for \code{lgb.Booster} models
#'
#' Attempts to save a model using RDS. Has an additional parameter (\code{raw}) which decides
#' whether to save the raw model or not.
#'
#' @param object R object to serialize.
#' @param file a connection or the name of the file where the R object is saved to.
#' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default),
#'              a binary one is used. See the comments in the help for save.
#' @param version the workspace format version to use. \code{NULL} specifies the current default
#'                version (3 since R 3.6.0, 2 in earlier versions of R). The only other supported
#'                value is 2.
#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression,
#'                 or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of
#'                 compression to be used. Ignored if file is a connection.
#' @param refhook a hook function for handling reference objects.
#' @param raw whether to save the model in a raw variable or not, recommended to leave it as
#'            \code{TRUE}. When \code{TRUE} and \code{object$raw} is \code{NA},
#'            \code{object$save()} is called before serializing, and \code{object$raw} is reset
#'            to \code{NA} when this function exits (also if serialization fails).
#'
#' @return \code{NULL} invisibly.
#'
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 10
#'     , valids = valids
#'     , min_data = 1
#'     , learning_rate = 1
#'     , early_stopping_rounds = 5
#' )
#' model_file <- tempfile(fileext = ".rds")
#' saveRDS.lgb.Booster(model, model_file)
#' @export
saveRDS.lgb.Booster <- function(object,
                                file = "",
                                ascii = FALSE,
                                version = NULL,
                                compress = TRUE,
                                refhook = NULL,
                                raw = TRUE) {

  # Store the raw model only when the caller asked for it and the object does
  # not already carry one. isTRUE() keeps the condition a single TRUE/FALSE
  # even if object$raw is NULL or not length 1 (where is.na() would not be).
  if (raw && isTRUE(is.na(object$raw))) {

    # Free the raw model again once we leave this function -- whether saveRDS()
    # succeeds or errors. Assumes `object` has reference semantics (R6 /
    # environment), which the in-place reset of object$raw relies on.
    on.exit(object$raw <- NA, add = TRUE)

    # Save model (presumably fills object$raw so it is included in the RDS)
    object$save()

  }

  saveRDS(
    object
    , file = file
    , ascii = ascii
    , version = version
    , compress = compress
    , refhook = refhook
  )

  # Match the documented contract regardless of which branch ran.
  invisible(NULL)
}