Commit 21575cb2 authored by James Lamb's avatar James Lamb Committed by Laurae
Browse files

updated R dependencies (#1619)

parent b2240721
...@@ -23,23 +23,23 @@ URL: https://github.com/Microsoft/LightGBM ...@@ -23,23 +23,23 @@ URL: https://github.com/Microsoft/LightGBM
BugReports: https://github.com/Microsoft/LightGBM/issues BugReports: https://github.com/Microsoft/LightGBM/issues
VignetteBuilder: knitr VignetteBuilder: knitr
Suggests: Suggests:
Ckmeans.1d.dp (>= 3.3.1),
DiagrammeR (>= 0.8.1),
ggplot2 (>= 1.0.1),
igraph (>= 1.0.1),
knitr, knitr,
rmarkdown, rmarkdown,
ggplot2 (>= 1.0.1), stringi (>= 0.5.2),
DiagrammeR (>= 0.8.1),
Ckmeans.1d.dp (>= 3.3.1),
vcd (>= 1.3),
testthat, testthat,
igraph (>= 1.0.1), vcd (>= 1.3)
stringi (>= 0.5.2)
Depends: Depends:
R (>= 3.0), R (>= 3.0),
R6 (>= 2.0) R6 (>= 2.0)
Imports: Imports:
graphics,
methods,
Matrix (>= 1.1-0),
data.table (>= 1.9.6), data.table (>= 1.9.6),
graphics,
jsonlite (>= 1.0),
magrittr (>= 1.5), magrittr (>= 1.5),
jsonlite (>= 1.0) Matrix (>= 1.1-0),
methods
RoxygenNote: 6.0.1 RoxygenNote: 6.0.1
...@@ -38,9 +38,15 @@ export(slice) ...@@ -38,9 +38,15 @@ export(slice)
import(methods) import(methods)
importFrom(R6,R6Class) importFrom(R6,R6Class)
importFrom(data.table,":=") importFrom(data.table,":=")
importFrom(data.table,data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,set) importFrom(data.table,set)
importFrom(graphics,barplot) importFrom(graphics,barplot)
importFrom(graphics,par) importFrom(graphics,par)
importFrom(jsonlite,fromJSON)
importFrom(magrittr,"%>%") importFrom(magrittr,"%>%")
importFrom(magrittr,"%T>%") importFrom(magrittr,"%T>%")
importFrom(magrittr,extract)
importFrom(magrittr,inset)
importFrom(methods,is)
useDynLib(lib_lightgbm) useDynLib(lib_lightgbm)
#' @importFrom methods is
Dataset <- R6Class( Dataset <- R6Class(
classname = "lgb.Dataset", classname = "lgb.Dataset",
cloneable = FALSE, cloneable = FALSE,
public = list( public = list(
# Finalize will free up the handles # Finalize will free up the handles
finalize = function() { finalize = function() {
# Check the need for freeing handle # Check the need for freeing handle
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
# Freeing up handle # Freeing up handle
lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle) lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
}, },
# Initialize will create a starter dataset # Initialize will create a starter dataset
initialize = function(data, initialize = function(data,
params = list(), params = list(),
...@@ -28,45 +30,45 @@ Dataset <- R6Class( ...@@ -28,45 +30,45 @@ Dataset <- R6Class(
used_indices = NULL, used_indices = NULL,
info = list(), info = list(),
...) { ...) {
# Check for additional parameters # Check for additional parameters
additional_params <- list(...) additional_params <- list(...)
# Create known attributes list # Create known attributes list
INFO_KEYS <- c("label", "weight", "init_score", "group") INFO_KEYS <- c("label", "weight", "init_score", "group")
# Check if attribute key is in the known attribute list # Check if attribute key is in the known attribute list
for (key in names(additional_params)) { for (key in names(additional_params)) {
# Key existing # Key existing
if (key %in% INFO_KEYS) { if (key %in% INFO_KEYS) {
# Store as info # Store as info
info[[key]] <- additional_params[[key]] info[[key]] <- additional_params[[key]]
} else { } else {
# Store as param # Store as param
params[[key]] <- additional_params[[key]] params[[key]] <- additional_params[[key]]
} }
} }
# Check for dataset reference # Check for dataset reference
if (!is.null(reference)) { if (!is.null(reference)) {
if (!lgb.check.r6.class(reference, "lgb.Dataset")) { if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
stop("lgb.Dataset: Can only use ", sQuote("lgb.Dataset"), " as reference") stop("lgb.Dataset: Can only use ", sQuote("lgb.Dataset"), " as reference")
} }
} }
# Check for predictor reference # Check for predictor reference
if (!is.null(predictor)) { if (!is.null(predictor)) {
if (!lgb.check.r6.class(predictor, "lgb.Predictor")) { if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
stop("lgb.Dataset: Only can use ", sQuote("lgb.Predictor"), " as predictor") stop("lgb.Dataset: Only can use ", sQuote("lgb.Predictor"), " as predictor")
} }
} }
# Check for matrix format # Check for matrix format
if (is.matrix(data)) { if (is.matrix(data)) {
# Check whether matrix is the correct type first ("double") # Check whether matrix is the correct type first ("double")
...@@ -74,7 +76,7 @@ Dataset <- R6Class( ...@@ -74,7 +76,7 @@ Dataset <- R6Class(
storage.mode(data) <- "double" storage.mode(data) <- "double"
} }
} }
# Setup private attributes # Setup private attributes
private$raw_data <- data private$raw_data <- data
private$params <- params private$params <- params
...@@ -86,13 +88,13 @@ Dataset <- R6Class( ...@@ -86,13 +88,13 @@ Dataset <- R6Class(
private$free_raw_data <- free_raw_data private$free_raw_data <- free_raw_data
private$used_indices <- used_indices private$used_indices <- used_indices
private$info <- info private$info <- info
}, },
create_valid = function(data, create_valid = function(data,
info = list(), info = list(),
...) { ...) {
# Create new dataset # Create new dataset
ret <- Dataset$new(data, ret <- Dataset$new(data,
private$params, private$params,
...@@ -104,61 +106,61 @@ Dataset <- R6Class( ...@@ -104,61 +106,61 @@ Dataset <- R6Class(
NULL, NULL,
info, info,
...) ...)
# Return ret # Return ret
return(invisible(ret)) return(invisible(ret))
}, },
# Dataset constructor # Dataset constructor
construct = function() { construct = function() {
# Check for handle null # Check for handle null
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
return(invisible(self)) return(invisible(self))
} }
# Get feature names # Get feature names
cnames <- NULL cnames <- NULL
if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) { if (is.matrix(private$raw_data) || methods::is(private$raw_data, "dgCMatrix")) {
cnames <- colnames(private$raw_data) cnames <- colnames(private$raw_data)
} }
# set feature names if not exist # set feature names if not exist
if (is.null(private$colnames) && !is.null(cnames)) { if (is.null(private$colnames) && !is.null(cnames)) {
private$colnames <- as.character(cnames) private$colnames <- as.character(cnames)
} }
# Get categorical feature index # Get categorical feature index
if (!is.null(private$categorical_feature)) { if (!is.null(private$categorical_feature)) {
# Check for character name # Check for character name
if (is.character(private$categorical_feature)) { if (is.character(private$categorical_feature)) {
cate_indices <- as.list(match(private$categorical_feature, private$colnames) - 1) cate_indices <- as.list(match(private$categorical_feature, private$colnames) - 1)
# Provided indices, but some indices are not existing? # Provided indices, but some indices are not existing?
if (sum(is.na(cate_indices)) > 0) { if (sum(is.na(cate_indices)) > 0) {
stop("lgb.self.get.handle: supplied an unknown feature in categorical_feature: ", sQuote(private$categorical_feature[is.na(cate_indices)])) stop("lgb.self.get.handle: supplied an unknown feature in categorical_feature: ", sQuote(private$categorical_feature[is.na(cate_indices)]))
} }
} else { } else {
# Check if more categorical features were output over the feature space # Check if more categorical features were output over the feature space
if (max(private$categorical_feature) > length(private$colnames)) { if (max(private$categorical_feature) > length(private$colnames)) {
stop("lgb.self.get.handle: supplied a too large value in categorical_feature: ", max(private$categorical_feature), " but only ", length(private$colnames), " features") stop("lgb.self.get.handle: supplied a too large value in categorical_feature: ", max(private$categorical_feature), " but only ", length(private$colnames), " features")
} }
# Store indices as [0, n-1] indexed instead of [1, n] indexed # Store indices as [0, n-1] indexed instead of [1, n] indexed
cate_indices <- as.list(private$categorical_feature - 1) cate_indices <- as.list(private$categorical_feature - 1)
} }
# Store indices for categorical features # Store indices for categorical features
private$params$categorical_feature <- cate_indices private$params$categorical_feature <- cate_indices
} }
# Check has header or not # Check has header or not
has_header <- FALSE has_header <- FALSE
if (!is.null(private$params$has_header) || !is.null(private$params$header)) { if (!is.null(private$params$has_header) || !is.null(private$params$header)) {
...@@ -166,31 +168,31 @@ Dataset <- R6Class( ...@@ -166,31 +168,31 @@ Dataset <- R6Class(
has_header <- TRUE has_header <- TRUE
} }
} }
# Generate parameter str # Generate parameter str
params_str <- lgb.params2str(private$params) params_str <- lgb.params2str(private$params)
# Get handle of reference dataset # Get handle of reference dataset
ref_handle <- NULL ref_handle <- NULL
if (!is.null(private$reference)) { if (!is.null(private$reference)) {
ref_handle <- private$reference$.__enclos_env__$private$get_handle() ref_handle <- private$reference$.__enclos_env__$private$get_handle()
} }
handle <- NA_real_ handle <- NA_real_
# Not subsetting # Not subsetting
if (is.null(private$used_indices)) { if (is.null(private$used_indices)) {
# Are we using a data file? # Are we using a data file?
if (is.character(private$raw_data)) { if (is.character(private$raw_data)) {
handle <- lgb.call("LGBM_DatasetCreateFromFile_R", handle <- lgb.call("LGBM_DatasetCreateFromFile_R",
ret = handle, ret = handle,
lgb.c_str(private$raw_data), lgb.c_str(private$raw_data),
params_str, params_str,
ref_handle) ref_handle)
} else if (is.matrix(private$raw_data)) { } else if (is.matrix(private$raw_data)) {
# Are we using a matrix? # Are we using a matrix?
handle <- lgb.call("LGBM_DatasetCreateFromMat_R", handle <- lgb.call("LGBM_DatasetCreateFromMat_R",
ret = handle, ret = handle,
...@@ -199,8 +201,8 @@ Dataset <- R6Class( ...@@ -199,8 +201,8 @@ Dataset <- R6Class(
ncol(private$raw_data), ncol(private$raw_data),
params_str, params_str,
ref_handle) ref_handle)
} else if (is(private$raw_data, "dgCMatrix")) { } else if (methods::is(private$raw_data, "dgCMatrix")) {
if (length(private$raw_data@p) > 2147483647) { if (length(private$raw_data@p) > 2147483647) {
stop("Cannot support large CSC matrix") stop("Cannot support large CSC matrix")
} }
...@@ -215,21 +217,21 @@ Dataset <- R6Class( ...@@ -215,21 +217,21 @@ Dataset <- R6Class(
nrow(private$raw_data), nrow(private$raw_data),
params_str, params_str,
ref_handle) ref_handle)
} else { } else {
# Unknown data type # Unknown data type
stop("lgb.Dataset.construct: does not support constructing from ", sQuote(class(private$raw_data))) stop("lgb.Dataset.construct: does not support constructing from ", sQuote(class(private$raw_data)))
} }
} else { } else {
# Reference is empty # Reference is empty
if (is.null(private$reference)) { if (is.null(private$reference)) {
stop("lgb.Dataset.construct: reference cannot be NULL for constructing data subset") stop("lgb.Dataset.construct: reference cannot be NULL for constructing data subset")
} }
# Construct subset # Construct subset
handle <- lgb.call("LGBM_DatasetGetSubset_R", handle <- lgb.call("LGBM_DatasetGetSubset_R",
ret = handle, ret = handle,
...@@ -237,7 +239,7 @@ Dataset <- R6Class( ...@@ -237,7 +239,7 @@ Dataset <- R6Class(
c(private$used_indices), # Adding c() fixes issue in R v3.5 c(private$used_indices), # Adding c() fixes issue in R v3.5
length(private$used_indices), length(private$used_indices),
params_str) params_str)
} }
if (lgb.is.null.handle(handle)) { if (lgb.is.null.handle(handle)) {
stop("lgb.Dataset.construct: cannot create Dataset handle") stop("lgb.Dataset.construct: cannot create Dataset handle")
...@@ -245,7 +247,7 @@ Dataset <- R6Class( ...@@ -245,7 +247,7 @@ Dataset <- R6Class(
# Setup class and private type # Setup class and private type
class(handle) <- "lgb.Dataset.handle" class(handle) <- "lgb.Dataset.handle"
private$handle <- handle private$handle <- handle
# Set feature names # Set feature names
if (!is.null(private$colnames)) { if (!is.null(private$colnames)) {
self$set_colnames(private$colnames) self$set_colnames(private$colnames)
...@@ -253,139 +255,139 @@ Dataset <- R6Class( ...@@ -253,139 +255,139 @@ Dataset <- R6Class(
# Load init score if requested # Load init score if requested
if (!is.null(private$predictor) && is.null(private$used_indices)) { if (!is.null(private$predictor) && is.null(private$used_indices)) {
# Setup initial scores # Setup initial scores
init_score <- private$predictor$predict(private$raw_data, rawscore = TRUE, reshape = TRUE) init_score <- private$predictor$predict(private$raw_data, rawscore = TRUE, reshape = TRUE)
# Not needed to transpose, for is col_marjor # Not needed to transpose, for is col_marjor
init_score <- as.vector(init_score) init_score <- as.vector(init_score)
private$info$init_score <- init_score private$info$init_score <- init_score
} }
# Should we free raw data? # Should we free raw data?
if (isTRUE(private$free_raw_data)) { if (isTRUE(private$free_raw_data)) {
private$raw_data <- NULL private$raw_data <- NULL
} }
# Get private information # Get private information
if (length(private$info) > 0) { if (length(private$info) > 0) {
# Set infos # Set infos
for (i in seq_along(private$info)) { for (i in seq_along(private$info)) {
p <- private$info[i] p <- private$info[i]
self$setinfo(names(p), p[[1]]) self$setinfo(names(p), p[[1]])
} }
} }
# Get label information existence # Get label information existence
if (is.null(self$getinfo("label"))) { if (is.null(self$getinfo("label"))) {
stop("lgb.Dataset.construct: label should be set") stop("lgb.Dataset.construct: label should be set")
} }
# Return self # Return self
return(invisible(self)) return(invisible(self))
}, },
# Dimension function # Dimension function
dim = function() { dim = function() {
# Check for handle # Check for handle
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
num_row <- 0L num_row <- 0L
num_col <- 0L num_col <- 0L
# Get numeric data and numeric features # Get numeric data and numeric features
c(lgb.call("LGBM_DatasetGetNumData_R", ret = num_row, private$handle), c(lgb.call("LGBM_DatasetGetNumData_R", ret = num_row, private$handle),
lgb.call("LGBM_DatasetGetNumFeature_R", ret = num_col, private$handle)) lgb.call("LGBM_DatasetGetNumFeature_R", ret = num_col, private$handle))
} else if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) { } else if (is.matrix(private$raw_data) || methods::is(private$raw_data, "dgCMatrix")) {
# Check if dgCMatrix (sparse matrix column compressed) # Check if dgCMatrix (sparse matrix column compressed)
dim(private$raw_data) dim(private$raw_data)
} else { } else {
# Trying to work with unknown dimensions is not possible # Trying to work with unknown dimensions is not possible
stop("dim: cannot get dimensions before dataset has been constructed, please call lgb.Dataset.construct explicitly") stop("dim: cannot get dimensions before dataset has been constructed, please call lgb.Dataset.construct explicitly")
} }
}, },
# Get column names # Get column names
get_colnames = function() { get_colnames = function() {
# Check for handle # Check for handle
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
# Get feature names and write them # Get feature names and write them
cnames <- lgb.call.return.str("LGBM_DatasetGetFeatureNames_R", private$handle) cnames <- lgb.call.return.str("LGBM_DatasetGetFeatureNames_R", private$handle)
private$colnames <- as.character(base::strsplit(cnames, "\t")[[1]]) private$colnames <- as.character(base::strsplit(cnames, "\t")[[1]])
private$colnames private$colnames
} else if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) { } else if (is.matrix(private$raw_data) || methods::is(private$raw_data, "dgCMatrix")) {
# Check if dgCMatrix (sparse matrix column compressed) # Check if dgCMatrix (sparse matrix column compressed)
colnames(private$raw_data) colnames(private$raw_data)
} else { } else {
# Trying to work with unknown dimensions is not possible # Trying to work with unknown dimensions is not possible
stop("dim: cannot get dimensions before dataset has been constructed, please call lgb.Dataset.construct explicitly") stop("dim: cannot get dimensions before dataset has been constructed, please call lgb.Dataset.construct explicitly")
} }
}, },
# Set column names # Set column names
set_colnames = function(colnames) { set_colnames = function(colnames) {
# Check column names non-existence # Check column names non-existence
if (is.null(colnames)) { if (is.null(colnames)) {
return(invisible(self)) return(invisible(self))
} }
# Check empty column names # Check empty column names
colnames <- as.character(colnames) colnames <- as.character(colnames)
if (length(colnames) == 0) { if (length(colnames) == 0) {
return(invisible(self)) return(invisible(self))
} }
# Write column names # Write column names
private$colnames <- colnames private$colnames <- colnames
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
# Merge names with tab separation # Merge names with tab separation
merged_name <- paste0(as.list(private$colnames), collapse = "\t") merged_name <- paste0(as.list(private$colnames), collapse = "\t")
lgb.call("LGBM_DatasetSetFeatureNames_R", lgb.call("LGBM_DatasetSetFeatureNames_R",
ret = NULL, ret = NULL,
private$handle, private$handle,
lgb.c_str(merged_name)) lgb.c_str(merged_name))
} }
# Return self # Return self
return(invisible(self)) return(invisible(self))
}, },
# Get information # Get information
getinfo = function(name) { getinfo = function(name) {
# Create known attributes list # Create known attributes list
INFONAMES <- c("label", "weight", "init_score", "group") INFONAMES <- c("label", "weight", "init_score", "group")
# Check if attribute key is in the known attribute list # Check if attribute key is in the known attribute list
if (!is.character(name) || length(name) != 1 || !name %in% INFONAMES) { if (!is.character(name) || length(name) != 1 || !name %in% INFONAMES) {
stop("getinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", ")) stop("getinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", "))
} }
# Check for info name and handle # Check for info name and handle
if (is.null(private$info[[name]])) { if (is.null(private$info[[name]])) {
if (lgb.is.null.handle(private$handle)){ if (lgb.is.null.handle(private$handle)){
...@@ -397,10 +399,10 @@ Dataset <- R6Class( ...@@ -397,10 +399,10 @@ Dataset <- R6Class(
ret = info_len, ret = info_len,
private$handle, private$handle,
lgb.c_str(name)) lgb.c_str(name))
# Check if info is not empty # Check if info is not empty
if (info_len > 0) { if (info_len > 0) {
# Get back fields # Get back fields
ret <- NULL ret <- NULL
ret <- if (name == "group") { ret <- if (name == "group") {
...@@ -408,65 +410,65 @@ Dataset <- R6Class( ...@@ -408,65 +410,65 @@ Dataset <- R6Class(
} else { } else {
numeric(info_len) # Numeric numeric(info_len) # Numeric
} }
ret <- lgb.call("LGBM_DatasetGetField_R", ret <- lgb.call("LGBM_DatasetGetField_R",
ret = ret, ret = ret,
private$handle, private$handle,
lgb.c_str(name)) lgb.c_str(name))
private$info[[name]] <- ret private$info[[name]] <- ret
} }
} }
private$info[[name]] private$info[[name]]
}, },
# Set information # Set information
setinfo = function(name, info) { setinfo = function(name, info) {
# Create known attributes list # Create known attributes list
INFONAMES <- c("label", "weight", "init_score", "group") INFONAMES <- c("label", "weight", "init_score", "group")
# Check if attribute key is in the known attribute list # Check if attribute key is in the known attribute list
if (!is.character(name) || length(name) != 1 || !name %in% INFONAMES) { if (!is.character(name) || length(name) != 1 || !name %in% INFONAMES) {
stop("setinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", ")) stop("setinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", "))
} }
# Check for type of information # Check for type of information
info <- if (name == "group") { info <- if (name == "group") {
as.integer(info) # Integer as.integer(info) # Integer
} else { } else {
as.numeric(info) # Numeric as.numeric(info) # Numeric
} }
# Store information privately # Store information privately
private$info[[name]] <- info private$info[[name]] <- info
if (!lgb.is.null.handle(private$handle) && !is.null(info)) { if (!lgb.is.null.handle(private$handle) && !is.null(info)) {
if (length(info) > 0) { if (length(info) > 0) {
lgb.call("LGBM_DatasetSetField_R", lgb.call("LGBM_DatasetSetField_R",
ret = NULL, ret = NULL,
private$handle, private$handle,
lgb.c_str(name), lgb.c_str(name),
info, info,
length(info)) length(info))
} }
} }
# Return self # Return self
return(invisible(self)) return(invisible(self))
}, },
# Slice dataset # Slice dataset
slice = function(idxset, ...) { slice = function(idxset, ...) {
# Perform slicing # Perform slicing
Dataset$new(NULL, Dataset$new(NULL,
private$params, private$params,
...@@ -478,84 +480,84 @@ Dataset <- R6Class( ...@@ -478,84 +480,84 @@ Dataset <- R6Class(
idxset, idxset,
NULL, NULL,
...) ...)
}, },
# Update parameters # Update parameters
update_params = function(params) { update_params = function(params) {
# Parameter updating # Parameter updating
private$params <- modifyList(private$params, params) private$params <- modifyList(private$params, params)
return(invisible(self)) return(invisible(self))
}, },
# Set categorical feature parameter # Set categorical feature parameter
set_categorical_feature = function(categorical_feature) { set_categorical_feature = function(categorical_feature) {
# Check for identical input # Check for identical input
if (identical(private$categorical_feature, categorical_feature)) { if (identical(private$categorical_feature, categorical_feature)) {
return(invisible(self)) return(invisible(self))
} }
# Check for empty data # Check for empty data
if (is.null(private$raw_data)) { if (is.null(private$raw_data)) {
stop("set_categorical_feature: cannot set categorical feature after freeing raw data, stop("set_categorical_feature: cannot set categorical feature after freeing raw data,
please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset") please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset")
} }
# Overwrite categorical features # Overwrite categorical features
private$categorical_feature <- categorical_feature private$categorical_feature <- categorical_feature
# Finalize and return self # Finalize and return self
self$finalize() self$finalize()
return(invisible(self)) return(invisible(self))
}, },
# Set reference # Set reference
set_reference = function(reference) { set_reference = function(reference) {
# Set known references # Set known references
self$set_categorical_feature(reference$.__enclos_env__$private$categorical_feature) self$set_categorical_feature(reference$.__enclos_env__$private$categorical_feature)
self$set_colnames(reference$get_colnames()) self$set_colnames(reference$get_colnames())
private$set_predictor(reference$.__enclos_env__$private$predictor) private$set_predictor(reference$.__enclos_env__$private$predictor)
# Check for identical references # Check for identical references
if (identical(private$reference, reference)) { if (identical(private$reference, reference)) {
return(invisible(self)) return(invisible(self))
} }
# Check for empty data # Check for empty data
if (is.null(private$raw_data)) { if (is.null(private$raw_data)) {
stop("set_reference: cannot set reference after freeing raw data, stop("set_reference: cannot set reference after freeing raw data,
please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset") please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset")
} }
# Check for non-existing reference # Check for non-existing reference
if (!is.null(reference)) { if (!is.null(reference)) {
# Reference is unknown # Reference is unknown
if (!lgb.check.r6.class(reference, "lgb.Dataset")) { if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
stop("set_reference: Can only use lgb.Dataset as a reference") stop("set_reference: Can only use lgb.Dataset as a reference")
} }
} }
# Store reference # Store reference
private$reference <- reference private$reference <- reference
# Finalize and return self # Finalize and return self
self$finalize() self$finalize()
return(invisible(self)) return(invisible(self))
}, },
# Save binary model # Save binary model
save_binary = function(fname) { save_binary = function(fname) {
# Store binary data # Store binary data
self$construct() self$construct()
lgb.call("LGBM_DatasetSaveBinary_R", lgb.call("LGBM_DatasetSaveBinary_R",
...@@ -564,7 +566,7 @@ Dataset <- R6Class( ...@@ -564,7 +566,7 @@ Dataset <- R6Class(
lgb.c_str(fname)) lgb.c_str(fname))
return(invisible(self)) return(invisible(self))
} }
), ),
private = list( private = list(
handle = NULL, handle = NULL,
...@@ -577,51 +579,51 @@ Dataset <- R6Class( ...@@ -577,51 +579,51 @@ Dataset <- R6Class(
free_raw_data = TRUE, free_raw_data = TRUE,
used_indices = NULL, used_indices = NULL,
info = NULL, info = NULL,
# Get handle # Get handle
get_handle = function() { get_handle = function() {
# Get handle and construct if needed # Get handle and construct if needed
if (lgb.is.null.handle(private$handle)) { if (lgb.is.null.handle(private$handle)) {
self$construct() self$construct()
} }
private$handle private$handle
}, },
# Set predictor # Set predictor
set_predictor = function(predictor) { set_predictor = function(predictor) {
# Return self is identical predictor # Return self is identical predictor
if (identical(private$predictor, predictor)) { if (identical(private$predictor, predictor)) {
return(invisible(self)) return(invisible(self))
} }
# Check for empty data # Check for empty data
if (is.null(private$raw_data)) { if (is.null(private$raw_data)) {
stop("set_predictor: cannot set predictor after free raw data, stop("set_predictor: cannot set predictor after free raw data,
please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset") please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset")
} }
# Check for empty predictor # Check for empty predictor
if (!is.null(predictor)) { if (!is.null(predictor)) {
# Predictor is unknown # Predictor is unknown
if (!lgb.check.r6.class(predictor, "lgb.Predictor")) { if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
stop("set_predictor: Can only use lgb.Predictor as predictor") stop("set_predictor: Can only use lgb.Predictor as predictor")
} }
} }
# Store predictor # Store predictor
private$predictor <- predictor private$predictor <- predictor
# Finalize and return self # Finalize and return self
self$finalize() self$finalize()
return(invisible(self)) return(invisible(self))
} }
) )
) )
...@@ -638,9 +640,9 @@ Dataset <- R6Class( ...@@ -638,9 +640,9 @@ Dataset <- R6Class(
#' @param free_raw_data TRUE for need to free raw data after construct #' @param free_raw_data TRUE for need to free raw data after construct
#' @param info a list of information of the lgb.Dataset object #' @param info a list of information of the lgb.Dataset object
#' @param ... other information to pass to \code{info} or parameters pass to \code{params} #' @param ... other information to pass to \code{info} or parameters pass to \code{params}
#' #'
#' @return constructed dataset #' @return constructed dataset
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
...@@ -651,7 +653,7 @@ Dataset <- R6Class( ...@@ -651,7 +653,7 @@ Dataset <- R6Class(
#' dtrain <- lgb.Dataset("lgb.Dataset.data") #' dtrain <- lgb.Dataset("lgb.Dataset.data")
#' lgb.Dataset.construct(dtrain) #' lgb.Dataset.construct(dtrain)
#' } #' }
#' #'
#' @export #' @export
lgb.Dataset <- function(data, lgb.Dataset <- function(data,
params = list(), params = list(),
...@@ -661,7 +663,7 @@ lgb.Dataset <- function(data, ...@@ -661,7 +663,7 @@ lgb.Dataset <- function(data,
free_raw_data = TRUE, free_raw_data = TRUE,
info = list(), info = list(),
...) { ...) {
# Create new dataset # Create new dataset
invisible(Dataset$new(data, invisible(Dataset$new(data,
params, params,
...@@ -673,20 +675,20 @@ lgb.Dataset <- function(data, ...@@ -673,20 +675,20 @@ lgb.Dataset <- function(data,
NULL, NULL,
info, info,
...)) ...))
} }
#' Construct validation data #' Construct validation data
#' #'
#' Construct validation data according to training data #' Construct validation data according to training data
#' #'
#' @param dataset \code{lgb.Dataset} object, training data #' @param dataset \code{lgb.Dataset} object, training data
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename #' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param info a list of information of the lgb.Dataset object #' @param info a list of information of the lgb.Dataset object
#' @param ... other information to pass to \code{info}. #' @param ... other information to pass to \code{info}.
#' #'
#' @return constructed dataset #' @return constructed dataset
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
...@@ -697,24 +699,24 @@ lgb.Dataset <- function(data, ...@@ -697,24 +699,24 @@ lgb.Dataset <- function(data,
#' test <- agaricus.test #' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label) #' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' } #' }
#' #'
#' @export #' @export
lgb.Dataset.create.valid <- function(dataset, data, info = list(), ...) { lgb.Dataset.create.valid <- function(dataset, data, info = list(), ...) {
# Check if dataset is not a dataset # Check if dataset is not a dataset
if (!lgb.is.Dataset(dataset)) { if (!lgb.is.Dataset(dataset)) {
stop("lgb.Dataset.create.valid: input data should be an lgb.Dataset object") stop("lgb.Dataset.create.valid: input data should be an lgb.Dataset object")
} }
# Create validation dataset # Create validation dataset
invisible(dataset$create_valid(data, info, ...)) invisible(dataset$create_valid(data, info, ...))
} }
#' Construct Dataset explicitly #' Construct Dataset explicitly
#' #'
#' @param dataset Object of class \code{lgb.Dataset} #' @param dataset Object of class \code{lgb.Dataset}
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
...@@ -723,56 +725,56 @@ lgb.Dataset.create.valid <- function(dataset, data, info = list(), ...) { ...@@ -723,56 +725,56 @@ lgb.Dataset.create.valid <- function(dataset, data, info = list(), ...) {
#' dtrain <- lgb.Dataset(train$data, label = train$label) #' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain) #' lgb.Dataset.construct(dtrain)
#' } #' }
#' #'
#' @export #' @export
lgb.Dataset.construct <- function(dataset) { lgb.Dataset.construct <- function(dataset) {
# Check if dataset is not a dataset # Check if dataset is not a dataset
if (!lgb.is.Dataset(dataset)) { if (!lgb.is.Dataset(dataset)) {
stop("lgb.Dataset.construct: input data should be an lgb.Dataset object") stop("lgb.Dataset.construct: input data should be an lgb.Dataset object")
} }
# Construct the dataset # Construct the dataset
invisible(dataset$construct()) invisible(dataset$construct())
} }
#' Dimensions of an lgb.Dataset #' Dimensions of an lgb.Dataset
#' #'
#' Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}. #' Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
#' @param x Object of class \code{lgb.Dataset} #' @param x Object of class \code{lgb.Dataset}
#' @param ... other parameters #' @param ... other parameters
#' #'
#' @return a vector of numbers of rows and of columns #' @return a vector of numbers of rows and of columns
#' #'
#' @details #' @details
#' Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also #' Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also
#' be directly used with an \code{lgb.Dataset} object. #' be directly used with an \code{lgb.Dataset} object.
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train #' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label) #' dtrain <- lgb.Dataset(train$data, label = train$label)
#' #'
#' stopifnot(nrow(dtrain) == nrow(train$data)) #' stopifnot(nrow(dtrain) == nrow(train$data))
#' stopifnot(ncol(dtrain) == ncol(train$data)) #' stopifnot(ncol(dtrain) == ncol(train$data))
#' stopifnot(all(dim(dtrain) == dim(train$data))) #' stopifnot(all(dim(dtrain) == dim(train$data)))
#' } #' }
#' #'
#' @rdname dim #' @rdname dim
#' @export #' @export
dim.lgb.Dataset <- function(x, ...) {
  # S3 dim() method: delegate to the R6 object's own dim(),
  # which returns c(number of rows, number of columns).
  if (lgb.is.Dataset(x)) {
    return(x$dim())
  }
  stop("dim.lgb.Dataset: input data should be an lgb.Dataset object")
}
#' Handling of column names of \code{lgb.Dataset} #' Handling of column names of \code{lgb.Dataset}
...@@ -800,76 +802,76 @@ dim.lgb.Dataset <- function(x, ...) { ...@@ -800,76 +802,76 @@ dim.lgb.Dataset <- function(x, ...) {
#' colnames(dtrain) <- make.names(1:ncol(train$data)) #' colnames(dtrain) <- make.names(1:ncol(train$data))
#' print(dtrain, verbose = TRUE) #' print(dtrain, verbose = TRUE)
#' } #' }
#' #'
#' @rdname dimnames.lgb.Dataset #' @rdname dimnames.lgb.Dataset
#' @export #' @export
dimnames.lgb.Dataset <- function(x) {
  # S3 dimnames() method. An lgb.Dataset never has row names,
  # so the first list element is always NULL.
  if (!lgb.is.Dataset(x)) {
    stop("dimnames.lgb.Dataset: input data should be an lgb.Dataset object")
  }
  col_names <- x$get_colnames()
  list(NULL, col_names)
}
#' @rdname dimnames.lgb.Dataset #' @rdname dimnames.lgb.Dataset
#' @export #' @export
`dimnames<-.lgb.Dataset` <- function(x, value) {
  # Replacement method for dimnames(): only column names can be set.
  #
  # `value` must be a two-element list; the first element (row names) must be
  # NULL because an lgb.Dataset has no row names. A NULL second element clears
  # the stored column names.
  #
  # Fix: integer literals normalized to the `L` form throughout (the original
  # mixed `value[[1L]]` with `value[[2]]`).
  if (!is.list(value) || length(value) != 2L) {
    stop("invalid ", sQuote("value"), " given: must be a list of two elements")
  }
  # Row names are not supported
  if (!is.null(value[[1L]])) {
    stop("lgb.Dataset does not have rownames")
  }
  new_colnames <- value[[2L]]
  # NULL clears the column names
  if (is.null(new_colnames)) {
    x$set_colnames(NULL)
    return(x)
  }
  # The replacement vector must match the column count exactly
  if (ncol(x) != length(new_colnames)) {
    stop("can't assign ", sQuote(length(new_colnames)), " colnames to an lgb.Dataset with ", sQuote(ncol(x)), " columns")
  }
  # Store the new names and return the (modified) object
  x$set_colnames(new_colnames)
  x
}
#' Slice a dataset #' Slice a dataset
#' #'
#' Get a new \code{lgb.Dataset} containing the specified rows of #' Get a new \code{lgb.Dataset} containing the specified rows of
#' original lgb.Dataset object
#' #'
#' @param dataset Object of class "lgb.Dataset" #' @param dataset Object of class "lgb.Dataset"
#' @param idxset an integer vector of indices of rows needed
#' @param ... other parameters (currently not used) #' @param ... other parameters (currently not used)
#' @return constructed sub dataset #' @return constructed sub dataset
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train #' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label) #' dtrain <- lgb.Dataset(train$data, label = train$label)
#' #'
#' dsub <- lightgbm::slice(dtrain, 1:42) #' dsub <- lightgbm::slice(dtrain, 1:42)
#' labels <- lightgbm::getinfo(dsub, "label") #' labels <- lightgbm::getinfo(dsub, "label")
#' } #' }
#' #'
#' @export #' @export
# S3 generic for slice(); the concrete method is chosen by the
# class of `dataset`.
slice <- function(dataset, ...) {
  UseMethod("slice")
}
#' @rdname slice #' @rdname slice
#' @export #' @export
slice.lgb.Dataset <- function(dataset, idxset, ...) {
  # slice() method: build a new Dataset from the selected row indices,
  # returned invisibly.
  if (lgb.is.Dataset(dataset)) {
    return(invisible(dataset$slice(idxset, ...)))
  }
  stop("slice.lgb.Dataset: input dataset should be an lgb.Dataset object")
}
#' Get information of an lgb.Dataset object #' Get information of an lgb.Dataset object
#' #'
#' @param dataset Object of class \code{lgb.Dataset} #' @param dataset Object of class \code{lgb.Dataset}
#' @param name the name of the information field to get (see details) #' @param name the name of the information field to get (see details)
#' @param ... other parameters #' @param ... other parameters
#' @return info data #' @return info data
#' #'
#' @details #' @details
#' The \code{name} field can be one of the following: #' The \code{name} field can be one of the following:
#' #'
#' \itemize{ #' \itemize{
#' \item \code{label}: label lightgbm learn from ; #' \item \code{label}: label lightgbm learn from ;
#' \item \code{weight}: to do a weight rescale ; #' \item \code{weight}: to do a weight rescale ;
#' \item \code{group}: group size #' \item \code{group}: group size
#' \item \code{init_score}: initial score is the base prediction lightgbm will boost from ; #' \item \code{init_score}: initial score is the base prediction lightgbm will boost from ;
#' } #' }
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
...@@ -913,14 +915,14 @@ slice.lgb.Dataset <- function(dataset, idxset, ...) { ...@@ -913,14 +915,14 @@ slice.lgb.Dataset <- function(dataset, idxset, ...) {
#' train <- agaricus.train #' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label) #' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain) #' lgb.Dataset.construct(dtrain)
#' #'
#' labels <- lightgbm::getinfo(dtrain, "label") #' labels <- lightgbm::getinfo(dtrain, "label")
#' lightgbm::setinfo(dtrain, "label", 1 - labels) #' lightgbm::setinfo(dtrain, "label", 1 - labels)
#' #'
#' labels2 <- lightgbm::getinfo(dtrain, "label") #' labels2 <- lightgbm::getinfo(dtrain, "label")
#' stopifnot(all(labels2 == 1 - labels)) #' stopifnot(all(labels2 == 1 - labels))
#' } #' }
#' #'
#' @export #' @export
# S3 generic for getinfo(); dispatches on the class of `dataset`.
getinfo <- function(dataset, ...) {
  UseMethod("getinfo")
}
#' @rdname getinfo #' @rdname getinfo
#' @export #' @export
getinfo.lgb.Dataset <- function(dataset, name, ...) {
  # getinfo() method: fetch an information field (e.g. label, weight,
  # group, init_score) from the Dataset.
  if (!lgb.is.Dataset(dataset)) {
    stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
  }
  field_value <- dataset$getinfo(name)
  field_value
}
#' Set information of an lgb.Dataset object #' Set information of an lgb.Dataset object
#' #'
#' @param dataset Object of class "lgb.Dataset" #' @param dataset Object of class "lgb.Dataset"
#' @param name the name of the field to get #' @param name the name of the field to get
#' @param info the specific field of information to set #' @param info the specific field of information to set
#' @param ... other parameters #' @param ... other parameters
#' @return passed object #' @return passed object
#' #'
#' @details #' @details
#' The \code{name} field can be one of the following: #' The \code{name} field can be one of the following:
#' #'
#' \itemize{ #' \itemize{
#' \item \code{label}: label lightgbm learn from ; #' \item \code{label}: label lightgbm learn from ;
#' \item \code{weight}: to do a weight rescale ; #' \item \code{weight}: to do a weight rescale ;
#' \item \code{init_score}: initial score is the base prediction lightgbm will boost from ; #' \item \code{init_score}: initial score is the base prediction lightgbm will boost from ;
#' \item \code{group}. #' \item \code{group}.
#' } #' }
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
...@@ -965,14 +967,14 @@ getinfo.lgb.Dataset <- function(dataset, name, ...) { ...@@ -965,14 +967,14 @@ getinfo.lgb.Dataset <- function(dataset, name, ...) {
#' train <- agaricus.train #' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label) #' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain) #' lgb.Dataset.construct(dtrain)
#' #'
#' labels <- lightgbm::getinfo(dtrain, "label") #' labels <- lightgbm::getinfo(dtrain, "label")
#' lightgbm::setinfo(dtrain, "label", 1 - labels) #' lightgbm::setinfo(dtrain, "label", 1 - labels)
#' #'
#' labels2 <- lightgbm::getinfo(dtrain, "label") #' labels2 <- lightgbm::getinfo(dtrain, "label")
#' stopifnot(all.equal(labels2, 1 - labels)) #' stopifnot(all.equal(labels2, 1 - labels))
#' } #' }
#' #'
#' @export #' @export
# S3 generic for setinfo(); dispatches on the class of `dataset`.
setinfo <- function(dataset, ...) {
  UseMethod("setinfo")
}
#' @rdname setinfo #' @rdname setinfo
#' @export #' @export
setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
  # setinfo() method: store an information field on the Dataset,
  # returning the result invisibly.
  if (lgb.is.Dataset(dataset)) {
    return(invisible(dataset$setinfo(name, info)))
  }
  stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
}
#' Set categorical feature of \code{lgb.Dataset} #' Set categorical feature of \code{lgb.Dataset}
#' #'
#' @param dataset object of class \code{lgb.Dataset} #' @param dataset object of class \code{lgb.Dataset}
#' @param categorical_feature categorical features #' @param categorical_feature categorical features
#' #'
#' @return passed dataset #' @return passed dataset
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
...@@ -1008,30 +1010,30 @@ setinfo.lgb.Dataset <- function(dataset, name, info, ...) { ...@@ -1008,30 +1010,30 @@ setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
#' dtrain <- lgb.Dataset("lgb.Dataset.data") #' dtrain <- lgb.Dataset("lgb.Dataset.data")
#' lgb.Dataset.set.categorical(dtrain, 1:2) #' lgb.Dataset.set.categorical(dtrain, 1:2)
#' } #' }
#' #'
#' @rdname lgb.Dataset.set.categorical #' @rdname lgb.Dataset.set.categorical
#' @export #' @export
lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {
  # Mark the given features as categorical on the Dataset.
  # Guard: the argument must be an lgb.Dataset
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.set.categorical: input dataset should be an lgb.Dataset object")
  }
  result <- dataset$set_categorical_feature(categorical_feature)
  invisible(result)
}
#' Set reference of \code{lgb.Dataset} #' Set reference of \code{lgb.Dataset}
#' #'
#' If you want to use validation data, you should set reference to training data #' If you want to use validation data, you should set reference to training data
#' #'
#' @param dataset object of class \code{lgb.Dataset} #' @param dataset object of class \code{lgb.Dataset}
#' @param reference object of class \code{lgb.Dataset} #' @param reference object of class \code{lgb.Dataset}
#' #'
#' @return passed dataset #' @return passed dataset
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
...@@ -1043,29 +1045,29 @@ lgb.Dataset.set.categorical <- function(dataset, categorical_feature) { ...@@ -1043,29 +1045,29 @@ lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {
#' dtest <- lgb.Dataset(test$data, label = test$label)
#' lgb.Dataset.set.reference(dtest, dtrain) #' lgb.Dataset.set.reference(dtest, dtrain)
#' } #' }
#' #'
#' @rdname lgb.Dataset.set.reference #' @rdname lgb.Dataset.set.reference
#' @export #' @export
lgb.Dataset.set.reference <- function(dataset, reference) {
  # Attach a reference (training) Dataset so that validation data uses
  # the same binning; result returned invisibly.
  if (lgb.is.Dataset(dataset)) {
    return(invisible(dataset$set_reference(reference)))
  }
  stop("lgb.Dataset.set.reference: input dataset should be an lgb.Dataset object")
}
#' Save \code{lgb.Dataset} to a binary file #' Save \code{lgb.Dataset} to a binary file
#' #'
#' @param dataset object of class \code{lgb.Dataset} #' @param dataset object of class \code{lgb.Dataset}
#' @param fname filename of the output file
#' #'
#' @return passed dataset #' @return passed dataset
#' #'
#' @examples #' @examples
#' #'
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -1073,21 +1075,21 @@ lgb.Dataset.set.reference <- function(dataset, reference) { ...@@ -1073,21 +1075,21 @@ lgb.Dataset.set.reference <- function(dataset, reference) {
#' dtrain <- lgb.Dataset(train$data, label = train$label) #' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.save(dtrain, "data.bin") #' lgb.Dataset.save(dtrain, "data.bin")
#' } #' }
#' #'
#' @rdname lgb.Dataset.save #' @rdname lgb.Dataset.save
#' @export #' @export
lgb.Dataset.save <- function(dataset, fname) {
  # Write the Dataset to a LightGBM binary file.
  #
  # @param dataset object of class lgb.Dataset
  # @param fname filename of the output file (character)
  # @return the passed dataset, invisibly
  #
  # Fix: both error messages previously carried the copy-pasted prefix
  # "lgb.Dataset.set:"; they now name this function.
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.save: input dataset should be an lgb.Dataset object")
  }
  # Only a character path is supported by save_binary(); the old message
  # wrongly claimed a file connection was also accepted.
  if (!is.character(fname)) {
    stop("lgb.Dataset.save: fname should be a character filename")
  }
  # Store binary representation on disk
  invisible(dataset$save_binary(fname))
}
#' @importFrom methods is
Predictor <- R6Class( Predictor <- R6Class(
classname = "lgb.Predictor", classname = "lgb.Predictor",
cloneable = FALSE, cloneable = FALSE,
public = list( public = list(
# Finalize will free up the handles # Finalize will free up the handles
finalize = function() { finalize = function() {
# Check the need for freeing handle # Check the need for freeing handle
if (private$need_free_handle && !lgb.is.null.handle(private$handle)) { if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {
# Freeing up handle # Freeing up handle
lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle) lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
}, },
# Initialize will create a starter model # Initialize will create a starter model
initialize = function(modelfile, ...) { initialize = function(modelfile, ...) {
params <- list(...) params <- list(...)
private$params <- lgb.params2str(params) private$params <- lgb.params2str(params)
# Create new lgb handle # Create new lgb handle
handle <- 0.0 handle <- 0.0
# Check if handle is a character # Check if handle is a character
if (is.character(modelfile)) { if (is.character(modelfile)) {
# Create handle on it # Create handle on it
handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret = handle, lgb.c_str(modelfile)) handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret = handle, lgb.c_str(modelfile))
private$need_free_handle <- TRUE private$need_free_handle <- TRUE
} else if (is(modelfile, "lgb.Booster.handle")) { } else if (methods::is(modelfile, "lgb.Booster.handle")) {
# Check if model file is a booster handle already # Check if model file is a booster handle already
handle <- modelfile handle <- modelfile
private$need_free_handle <- FALSE private$need_free_handle <- FALSE
} else { } else {
# Model file is unknown # Model file is unknown
stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle") stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")
} }
# Override class and store it # Override class and store it
class(handle) <- "lgb.Booster.handle" class(handle) <- "lgb.Booster.handle"
private$handle <- handle private$handle <- handle
}, },
# Get current iteration # Get current iteration
current_iter = function() { current_iter = function() {
cur_iter <- 0L cur_iter <- 0L
lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle) lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)
}, },
# Predict from data # Predict from data
predict = function(data, predict = function(data,
num_iteration = NULL, num_iteration = NULL,
...@@ -66,22 +68,22 @@ Predictor <- R6Class( ...@@ -66,22 +68,22 @@ Predictor <- R6Class(
predcontrib = FALSE, predcontrib = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE) { reshape = FALSE) {
# Check if number of iterations is existing - if not, then set it to -1 (use all) # Check if number of iterations is existing - if not, then set it to -1 (use all)
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
num_iteration <- -1 num_iteration <- -1
} }
# Set temporary variable # Set temporary variable
num_row <- 0L num_row <- 0L
# Check if data is a file name # Check if data is a file name
if (is.character(data)) { if (is.character(data)) {
# Data is a filename, create a temporary file with a "lightgbm_" pattern in it # Data is a filename, create a temporary file with a "lightgbm_" pattern in it
tmp_filename <- tempfile(pattern = "lightgbm_") tmp_filename <- tempfile(pattern = "lightgbm_")
on.exit(unlink(tmp_filename), add = TRUE) on.exit(unlink(tmp_filename), add = TRUE)
# Predict from temporary file # Predict from temporary file
lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data, lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data,
as.integer(header), as.integer(header),
...@@ -91,19 +93,19 @@ Predictor <- R6Class( ...@@ -91,19 +93,19 @@ Predictor <- R6Class(
as.integer(num_iteration), as.integer(num_iteration),
private$params, private$params,
lgb.c_str(tmp_filename)) lgb.c_str(tmp_filename))
# Get predictions from file # Get predictions from file
preds <- read.delim(tmp_filename, header = FALSE, seq = "\t") preds <- read.delim(tmp_filename, header = FALSE, seq = "\t")
num_row <- nrow(preds) num_row <- nrow(preds)
preds <- as.vector(t(preds)) preds <- as.vector(t(preds))
} else { } else {
# Not a file, we need to predict from R object # Not a file, we need to predict from R object
num_row <- nrow(data) num_row <- nrow(data)
npred <- 0L npred <- 0L
# Check number of predictions to do # Check number of predictions to do
npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", npred <- lgb.call("LGBM_BoosterCalcNumPredict_R",
ret = npred, ret = npred,
...@@ -113,10 +115,10 @@ Predictor <- R6Class( ...@@ -113,10 +115,10 @@ Predictor <- R6Class(
as.integer(predleaf), as.integer(predleaf),
as.integer(predcontrib), as.integer(predcontrib),
as.integer(num_iteration)) as.integer(num_iteration))
# Pre-allocate empty vector # Pre-allocate empty vector
preds <- numeric(npred) preds <- numeric(npred)
# Check if data is a matrix # Check if data is a matrix
if (is.matrix(data)) { if (is.matrix(data)) {
preds <- lgb.call("LGBM_BoosterPredictForMat_R", preds <- lgb.call("LGBM_BoosterPredictForMat_R",
...@@ -130,8 +132,8 @@ Predictor <- R6Class( ...@@ -130,8 +132,8 @@ Predictor <- R6Class(
as.integer(predcontrib), as.integer(predcontrib),
as.integer(num_iteration), as.integer(num_iteration),
private$params) private$params)
} else if (is(data, "dgCMatrix")) { } else if (methods::is(data, "dgCMatrix")) {
if (length(data@p) > 2147483647) { if (length(data@p) > 2147483647) {
stop("Cannot support large CSC matrix") stop("Cannot support large CSC matrix")
} }
...@@ -150,44 +152,44 @@ Predictor <- R6Class( ...@@ -150,44 +152,44 @@ Predictor <- R6Class(
as.integer(predcontrib), as.integer(predcontrib),
as.integer(num_iteration), as.integer(num_iteration),
private$params) private$params)
} else { } else {
# Cannot predict on unknown class # Cannot predict on unknown class
# to-do: predict from lgb.Dataset # to-do: predict from lgb.Dataset
stop("predict: cannot predict on data of class ", sQuote(class(data))) stop("predict: cannot predict on data of class ", sQuote(class(data)))
} }
} }
# Check if number of rows is strange (not a multiple of the dataset rows) # Check if number of rows is strange (not a multiple of the dataset rows)
if (length(preds) %% num_row != 0) { if (length(preds) %% num_row != 0) {
stop("predict: prediction length ", sQuote(length(preds))," is not a multiple of nrows(data): ", sQuote(num_row)) stop("predict: prediction length ", sQuote(length(preds))," is not a multiple of nrows(data): ", sQuote(num_row))
} }
# Get number of cases per row # Get number of cases per row
npred_per_case <- length(preds) / num_row npred_per_case <- length(preds) / num_row
# Data reshaping # Data reshaping
if (predleaf | predcontrib) { if (predleaf | predcontrib) {
# Predict leaves only, reshaping is mandatory # Predict leaves only, reshaping is mandatory
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE) preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
} else if (reshape && npred_per_case > 1) { } else if (reshape && npred_per_case > 1) {
# Predict with data reshaping # Predict with data reshaping
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE) preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
} }
# Return predictions # Return predictions
return(preds) return(preds)
} }
), ),
private = list(handle = NULL, private = list(handle = NULL,
need_free_handle = FALSE, need_free_handle = FALSE,
......
#' Parse a LightGBM model json dump #' Parse a LightGBM model json dump
#' #'
#' Parse a LightGBM model json dump into a \code{data.table} structure. #' Parse a LightGBM model json dump into a \code{data.table} structure.
#' #'
#' @param model object of class \code{lgb.Booster} #' @param model object of class \code{lgb.Booster}
#' @param num_iteration number of iterations you want to predict with. NULL or #' @param num_iteration number of iterations you want to predict with. NULL or
#' <= 0 means use best iteration #' <= 0 means use best iteration
#' #'
#' @return #' @return
#' A \code{data.table} with detailed information about model trees' nodes and leafs. #' A \code{data.table} with detailed information about model trees' nodes and leafs.
#' #'
#' The columns of the \code{data.table} are: #' The columns of the \code{data.table} are:
#' #'
#' \itemize{ #' \itemize{
#' \item \code{tree_index}: ID of a tree in a model (integer) #' \item \code{tree_index}: ID of a tree in a model (integer)
#' \item \code{split_index}: ID of a node in a tree (integer) #' \item \code{split_index}: ID of a node in a tree (integer)
...@@ -28,11 +28,11 @@ ...@@ -28,11 +28,11 @@
#' \item \code{leaf_value}: Leaf value #' \item \code{leaf_value}: Leaf value
#' \item \code{leaf_count}: The number of observation collected by a leaf #' \item \code{leaf_count}: The number of observation collected by a leaf
#' } #' }
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' #'
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train #' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label) #' dtrain <- lgb.Dataset(train$data, label = train$label)
...@@ -45,43 +45,45 @@ ...@@ -45,43 +45,45 @@
#' #'
#' tree_dt <- lgb.model.dt.tree(model) #' tree_dt <- lgb.model.dt.tree(model)
#' } #' }
#' #'
#' @importFrom magrittr %>% #' @importFrom magrittr %>%
#' @importFrom data.table := #' @importFrom data.table := data.table
#' @importFrom jsonlite fromJSON
#' @export #' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {
  # Dump the booster to its JSON representation, then parse it back
  # without simplifying nested structures (trees stay as nested lists)
  model_json <- lgb.dump(model, num_iteration = num_iteration)
  model_list <- jsonlite::fromJSON(model_json,
                                   simplifyVector = TRUE,
                                   simplifyDataFrame = FALSE,
                                   simplifyMatrix = FALSE,
                                   flatten = FALSE)
  # Parse every tree and stack the per-tree tables into one data.table
  tree_dt <- data.table::rbindlist(lapply(model_list$tree_info, single.tree.parse),
                                   use.names = TRUE)
  # Translate integer feature indices back to the original feature names
  tree_dt[, split_feature := Lookup(split_feature,
                                    seq.int(from = 0, to = model_list$max_feature_idx),
                                    model_list$feature_names)]
  # Hand back the combined table
  return(tree_dt)
}
#' @importFrom data.table data.table rbindlist
single.tree.parse <- function(lgb_tree) { single.tree.parse <- function(lgb_tree) {
# Traverse tree function # Traverse tree function
pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) { pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) {
if (is.null(env)) { if (is.null(env)) {
# Setup initial default data.table with default types # Setup initial default data.table with default types
env <- new.env(parent = emptyenv()) env <- new.env(parent = emptyenv())
...@@ -103,10 +105,10 @@ single.tree.parse <- function(lgb_tree) { ...@@ -103,10 +105,10 @@ single.tree.parse <- function(lgb_tree) {
# start tree traversal # start tree traversal
pre_order_traversal(env, tree_node_leaf, current_depth, parent_index) pre_order_traversal(env, tree_node_leaf, current_depth, parent_index)
} else { } else {
# Check if split index is not null in leaf # Check if split index is not null in leaf
if (!is.null(tree_node_leaf$split_index)) { if (!is.null(tree_node_leaf$split_index)) {
# update data.table # update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt, env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("split_index", c(tree_node_leaf[c("split_index",
...@@ -121,7 +123,7 @@ single.tree.parse <- function(lgb_tree) { ...@@ -121,7 +123,7 @@ single.tree.parse <- function(lgb_tree) {
"node_parent" = parent_index)), "node_parent" = parent_index)),
use.names = TRUE, use.names = TRUE,
fill = TRUE) fill = TRUE)
# Traverse tree again both left and right # Traverse tree again both left and right
pre_order_traversal(env, pre_order_traversal(env,
tree_node_leaf$left_child, tree_node_leaf$left_child,
...@@ -131,9 +133,9 @@ single.tree.parse <- function(lgb_tree) { ...@@ -131,9 +133,9 @@ single.tree.parse <- function(lgb_tree) {
tree_node_leaf$right_child, tree_node_leaf$right_child,
current_depth = current_depth + 1L, current_depth = current_depth + 1L,
parent_index = tree_node_leaf$split_index) parent_index = tree_node_leaf$split_index)
} else if (!is.null(tree_node_leaf$leaf_index)) { } else if (!is.null(tree_node_leaf$leaf_index)) {
# update data.table # update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt, env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("leaf_index", c(tree_node_leaf[c("leaf_index",
...@@ -143,29 +145,30 @@ single.tree.parse <- function(lgb_tree) { ...@@ -143,29 +145,30 @@ single.tree.parse <- function(lgb_tree) {
"leaf_parent" = parent_index)), "leaf_parent" = parent_index)),
use.names = TRUE, use.names = TRUE,
fill = TRUE) fill = TRUE)
} }
} }
return(env$single_tree_dt) return(env$single_tree_dt)
} }
# Traverse structure # Traverse structure
single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure) single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
# Store index # Store index
single_tree_dt[, tree_index := lgb_tree$tree_index] single_tree_dt[, tree_index := lgb_tree$tree_index]
# Return tree # Return tree
return(single_tree_dt) return(single_tree_dt)
} }
# Vectorized lookup-table translation.
#
# For each element of `key`, find its position in `key_lookup` and return the
# value at that position in `value_lookup`; entries with no match (or whose
# looked-up value is NA) are replaced by `missing`.
#
# @param key vector of keys to translate
# @param key_lookup vector of known keys, parallel to `value_lookup`
# @param value_lookup vector of values, parallel to `key_lookup`
# @param missing value substituted where no lookup result is available
# @return vector of translated values, the same length as `key`
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {
  # Plain base-R form of the previous magrittr pipeline
  # (match %>% extract %>% inset); behavior is unchanged, and the
  # magrittr extract/inset imports are no longer needed.
  values <- value_lookup[match(key, key_lookup)]
  values[is.na(values)] <- missing
  values
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment