readRDS.lgb.Booster.R 1.52 KB
Newer Older
Nikita Titov's avatar
Nikita Titov committed
1
#' readRDS for \code{lgb.Booster} models
2
#'
James Lamb's avatar
James Lamb committed
3
#' Attempts to load a model from an RDS file (e.g. one written by \code{saveRDS.lgb.Booster}).
4
#'
5
6
#' @param file a connection or the name of the file from which the R object is read.
#' @param refhook a hook function for handling reference objects, passed on to \code{readRDS}.
7
#'
Nikita Titov's avatar
Nikita Titov committed
8
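#' @return \code{lgb.Booster}, rebuilt from the model string stored in the file when one is present; otherwise the object exactly as returned by \code{readRDS}.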
9
#'
10
#' @examples
11
12
13
14
15
16
17
18
19
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
20
21
22
23
24
25
26
27
28
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10
#'   , valids = valids
#'   , min_data = 1
#'   , learning_rate = 1
#'   , early_stopping_rounds = 5
#' )
29
30
#' model_file <- tempfile(fileext = ".rds")
#' saveRDS.lgb.Booster(model, model_file)
#' new_model <- readRDS.lgb.Booster(model_file)
31
#'
32
33
#' @export
readRDS.lgb.Booster <- function(file = "", refhook = NULL) {

  # Read RDS file
  object <- readRDS(file = file, refhook = refhook)

  # Look up the stored model string. Only lists and environments (R6 objects
  # are environments) can carry one; anything else readRDS returns is passed
  # through untouched. `[[` is used instead of `$` so that a list element
  # such as `rawdata` is never partially matched as `raw`.
  raw_model <- NULL
  if (is.list(object) || is.environment(object)) {
    raw_model <- object[["raw"]]
  }

  # `raw` is NA when no model string was stored, and NULL when the field is
  # absent altogether. In both cases return the object exactly as read.
  # Checking the length first avoids `if (logical(0))` erroring on NULL.
  if (length(raw_model) == 0L || anyNA(raw_model)) {
    return(object)
  }

  # Create temporary model for the model loading
  object2 <- lgb.load(model_str = raw_model)

  # Restore best iteration and recorded evaluations
  object2$best_iter <- object[["best_iter"]]
  object2$record_evals <- object[["record_evals"]]

  # Return newly loaded object
  return(object2)

}