# saveRDS.lgb.Booster.R
#' saveRDS for lgb.Booster models
#'
#' Attempts to save a model using RDS. Has an additional parameter
#' (\code{raw}) which decides whether to save the raw model or not.
#'
#' @param object an \code{lgb.Booster} object to serialize.
#' @param file a connection or the name of the file where the R object is
#'   saved to.
#' @param ascii a logical. If TRUE or NA, an ASCII representation is written;
#'   otherwise (default), a binary one is used. See the comments in the help
#'   for save.
#' @param version the workspace format version to use. \code{NULL} specifies
#'   the current default version (see \code{\link[base]{saveRDS}}).
#' @param compress a logical specifying whether saving to a named file is to
#'   use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or
#'   \code{"xz"} to indicate the type of compression to be used. Ignored if
#'   file is a connection.
#' @param refhook a hook function for handling reference objects.
#' @param raw if \code{TRUE} (recommended) and \code{object$raw} is \code{NA},
#'   \code{object$save()} is called before serialization and
#'   \code{object$raw} is reset to \code{NA} afterwards. If \code{FALSE}, the
#'   object is serialized as-is.
#'
#' @return NULL invisibly.
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' saveRDS.lgb.Booster(model, "model.rds")
#' }
#'
#' @export
saveRDS.lgb.Booster <- function(object,
                                file = "",
                                ascii = FALSE,
                                version = NULL,
                                compress = TRUE,
                                refhook = NULL,
                                raw = TRUE) {

  if (!inherits(object, "lgb.Booster")) {
    stop(
      "saveRDS.lgb.Booster: object should be an ", sQuote("lgb.Booster"),
      call. = FALSE
    )
  }

  # Save the model first only when no raw copy is cached yet (and the user
  # asked for it). isTRUE() keeps the condition scalar: once object$raw holds
  # a raw vector, is.na() returns one value per element, which `&&` rejects
  # in R >= 4.3.
  if (isTRUE(is.na(object$raw)) && raw) {
    object$save()
    # Free the cached copy once serialization is done -- also when saveRDS()
    # fails -- so a later call does not serialize a stale cached model and
    # instead saves the model again.
    on.exit(object$raw <- NA, add = TRUE)
  }

  saveRDS(object,
          file = file,
          ascii = ascii,
          version = version,
          compress = compress,
          refhook = refhook)

  invisible(NULL)
}