# readRDS.lgb.Booster.R
#' readRDS for lgb.Booster models
#'
#' Attempts to load a model using RDS. If the object read from \code{file}
#' carries a model string in its \code{raw} field (as written by
#' \code{saveRDS.lgb.Booster}), a new booster is rebuilt from that string with
#' \code{lgb.load} and its best iteration and recorded evaluations are
#' restored. Otherwise the object read from disk is returned unchanged.
#'
#' @param file a connection or the name of the file where the R object is
#'   read from.
#' @param refhook a hook function for handling reference objects.
#'
#' @return lgb.Booster when the file holds an object with a stored model
#'   string; otherwise the object exactly as returned by \code{readRDS}.
#'
#' @examples
#' \dontrun{
#'   library(lightgbm)
#'   data(agaricus.train, package = "lightgbm")
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label = train$label)
#'   data(agaricus.test, package = "lightgbm")
#'   test <- agaricus.test
#'   dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#'   params <- list(objective = "regression", metric = "l2")
#'   valids <- list(test = dtest)
#'   model <- lgb.train(params,
#'                      dtrain,
#'                      100,
#'                      valids,
#'                      min_data = 1,
#'                      learning_rate = 1,
#'                      early_stopping_rounds = 10)
#'   saveRDS.lgb.Booster(model, "model.rds")
#'   new_model <- readRDS.lgb.Booster("model.rds")
#' }
#'
#' @export
readRDS.lgb.Booster <- function(file = "", refhook = NULL) {

  # Read RDS file
  object <- readRDS(file = file, refhook = refhook)

  # Only lists/environments can carry a `raw` field; anything else (e.g. an
  # atomic vector) is returned as read instead of erroring on `$`.
  if (!is.list(object) && !is.environment(object)) {
    return(object)
  }

  # `[[` matches names exactly; `$` would partially match on plain lists.
  raw <- object[["raw"]]

  # A stored model is a single non-NA value. NULL (field absent) used to make
  # `if (!is.na(...))` fail with "argument is of length zero".
  has_model_string <- length(raw) == 1L && !is.na(raw)
  if (!has_model_string) {
    # Return RDS loaded object
    return(object)
  }

  # Create a fresh model from the stored model string
  booster <- lgb.load(model_str = raw)

  # Restore best iteration and recorded evaluations
  booster$best_iter <- object[["best_iter"]]
  booster$record_evals <- object[["record_evals"]]

  # Return newly loaded object
  booster

}