"vscode:/vscode.git/clone" did not exist on "d4cdbcfe1b1db2d48e95665a8b358028c496cdb2"
Commit 06a915a7 authored by Laurae's avatar Laurae Committed by Guolin Ke
Browse files

[R-package] Add $raw to LightGBM (#347)

* Update lgb.Booster.R

* Add saveRDS (manual)

* Add documentation

* Change arguments passed from debug.

* Add readRDS, change way of saving.

* Change documentation.

* Add better documentation.
parent cc62b1c3
...@@ -26,6 +26,8 @@ export(lgb.plot.interpretation) ...@@ -26,6 +26,8 @@ export(lgb.plot.interpretation)
export(lgb.save) export(lgb.save)
export(lgb.train) export(lgb.train)
export(lightgbm) export(lightgbm)
export(readRDS.lgb.Booster)
export(saveRDS.lgb.Booster)
export(setinfo) export(setinfo)
export(slice) export(slice)
import(methods) import(methods)
......
...@@ -195,7 +195,14 @@ Booster <- R6Class( ...@@ -195,7 +195,14 @@ Booster <- R6Class(
predictor <- Predictor$new(private$handle) predictor <- Predictor$new(private$handle)
predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape) predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)
}, },
to_predictor = function() { Predictor$new(private$handle) } to_predictor = function() { Predictor$new(private$handle) },
raw = NA,
save = function() {
temp <- tempfile()
lgb.save(self, temp)
self$raw <- readChar(temp, file.info(temp)$size)
file.remove(temp)
}
), ),
private = list( private = list(
handle = NULL, handle = NULL,
......
#' readRDS for lgb.Booster models
#'
#' Attemps to load a model using RDS.
#'
#' @param file a connection or the name of the file where the R object is saved to or read from.
#' @param refhook a hook function for handling reference objects.
#'
#' @return an R object.
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' params <- list(objective="regression", metric="l2")
#' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' saveRDS.lgb.Booster(model, "model.rds")
#' new_model <- readRDS.lgb.Booster("model.rds")
#' }
#' @export
readRDS.lgb.Booster <- function(file = "", refhook = NULL) {
object <- readRDS(file = file, refhook = refhook)
if (!is.na(object$raw)) {
temp <- tempfile()
write(object$raw, temp)
object2 <- lgb.load(temp)
file.remove(temp)
object2$best_iter <- object$best_iter
object2$record_evals <- object$record_evals
return(object2)
} else {
return(object)
}
}
#' saveRDS for lgb.Booster models
#'
#' Attemps to save a model using RDS. Has an additional parameter (\code{raw}) which decides whether to save the raw model or not.
#'
#' @param object R object to serialize.
#' @param file a connection or the name of the file where the R object is saved to or read from.
#' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save.
#' @param version the workspace format version to use. \code{NULL} specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions.
#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of compression to be used. Ignored if file is a connection.
#' @param refhook a hook function for handling reference objects.
#' @param raw whether to save the model in a raw variable or not, recommended to leave it to \code{TRUE}.
#'
#' @return NULL invisibly.
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' params <- list(objective="regression", metric="l2")
#' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' saveRDS.lgb.Booster(model, "model.rds")
#' }
#' @export
saveRDS.lgb.Booster <- function(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL, raw = TRUE) {
if (is.na(object$raw) & (raw)) {
object$save()
saveRDS(object, file = file, ascii = ascii, version = version, compress = compress, refhook = refhook)
object$raw <- NA
} else {
saveRDS(object, file = file, ascii = ascii, version = version, compress = compress, refhook = refhook)
}
}
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/readRDS.lgb.Booster.R
\name{readRDS.lgb.Booster}
\alias{readRDS.lgb.Booster}
\title{readRDS for lgb.Booster models}
\usage{
readRDS.lgb.Booster(file = "", refhook = NULL)
}
\arguments{
\item{file}{a connection or the name of the file where the R object is saved to or read from.}
\item{refhook}{a hook function for handling reference objects.}
}
\value{
an R object.
}
\description{
Attemps to load a model using RDS.
}
\examples{
\dontrun{
library(lightgbm)
data(agaricus.train, package='lightgbm')
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label=train$label)
data(agaricus.test, package='lightgbm')
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
params <- list(objective="regression", metric="l2")
valids <- list(test=dtest)
model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
saveRDS.lgb.Booster(model, "model.rds")
new_model <- readRDS.lgb.Booster("model.rds")
}
}
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/saveRDS.lgb.Booster.R
\name{saveRDS.lgb.Booster}
\alias{saveRDS.lgb.Booster}
\title{saveRDS for lgb.Booster models}
\usage{
saveRDS.lgb.Booster(object, file = "", ascii = FALSE, version = NULL,
compress = TRUE, refhook = NULL, raw = TRUE)
}
\arguments{
\item{object}{R object to serialize.}
\item{file}{a connection or the name of the file where the R object is saved to or read from.}
\item{ascii}{a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save.}
\item{version}{the workspace format version to use. \code{NULL} specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions.}
\item{compress}{a logical specifying whether saving to a named file is to use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of compression to be used. Ignored if file is a connection.}
\item{refhook}{a hook function for handling reference objects.}
\item{raw}{whether to save the model in a raw variable or not, recommended to leave it to \code{TRUE}.}
}
\value{
NULL invisibly.
}
\description{
Attemps to save a model using RDS. Has an additional parameter (\code{raw}) which decides whether to save the raw model or not.
}
\examples{
\dontrun{
library(lightgbm)
data(agaricus.train, package='lightgbm')
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label=train$label)
data(agaricus.test, package='lightgbm')
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
params <- list(objective="regression", metric="l2")
valids <- list(test=dtest)
model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
saveRDS.lgb.Booster(model, "model.rds")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment