Commit 21575cb2 authored by James Lamb's avatar James Lamb Committed by Laurae
Browse files

updated R dependencies (#1619)

parent b2240721
...@@ -23,23 +23,23 @@ URL: https://github.com/Microsoft/LightGBM ...@@ -23,23 +23,23 @@ URL: https://github.com/Microsoft/LightGBM
BugReports: https://github.com/Microsoft/LightGBM/issues BugReports: https://github.com/Microsoft/LightGBM/issues
VignetteBuilder: knitr VignetteBuilder: knitr
Suggests: Suggests:
Ckmeans.1d.dp (>= 3.3.1),
DiagrammeR (>= 0.8.1),
ggplot2 (>= 1.0.1),
igraph (>= 1.0.1),
knitr, knitr,
rmarkdown, rmarkdown,
ggplot2 (>= 1.0.1), stringi (>= 0.5.2),
DiagrammeR (>= 0.8.1),
Ckmeans.1d.dp (>= 3.3.1),
vcd (>= 1.3),
testthat, testthat,
igraph (>= 1.0.1), vcd (>= 1.3)
stringi (>= 0.5.2)
Depends: Depends:
R (>= 3.0), R (>= 3.0),
R6 (>= 2.0) R6 (>= 2.0)
Imports: Imports:
graphics,
methods,
Matrix (>= 1.1-0),
data.table (>= 1.9.6), data.table (>= 1.9.6),
graphics,
jsonlite (>= 1.0),
magrittr (>= 1.5), magrittr (>= 1.5),
jsonlite (>= 1.0) Matrix (>= 1.1-0),
methods
RoxygenNote: 6.0.1 RoxygenNote: 6.0.1
...@@ -38,9 +38,15 @@ export(slice) ...@@ -38,9 +38,15 @@ export(slice)
import(methods) import(methods)
importFrom(R6,R6Class) importFrom(R6,R6Class)
importFrom(data.table,":=") importFrom(data.table,":=")
importFrom(data.table,data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,set) importFrom(data.table,set)
importFrom(graphics,barplot) importFrom(graphics,barplot)
importFrom(graphics,par) importFrom(graphics,par)
importFrom(jsonlite,fromJSON)
importFrom(magrittr,"%>%") importFrom(magrittr,"%>%")
importFrom(magrittr,"%T>%") importFrom(magrittr,"%T>%")
importFrom(magrittr,extract)
importFrom(magrittr,inset)
importFrom(methods,is)
useDynLib(lib_lightgbm) useDynLib(lib_lightgbm)
This diff is collapsed.
#' @importFrom methods is
Predictor <- R6Class( Predictor <- R6Class(
classname = "lgb.Predictor", classname = "lgb.Predictor",
cloneable = FALSE, cloneable = FALSE,
public = list( public = list(
# Finalize will free up the handles # Finalize will free up the handles
finalize = function() { finalize = function() {
# Check the need for freeing handle # Check the need for freeing handle
if (private$need_free_handle && !lgb.is.null.handle(private$handle)) { if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {
# Freeing up handle # Freeing up handle
lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle) lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
}, },
# Initialize will create a starter model # Initialize will create a starter model
initialize = function(modelfile, ...) { initialize = function(modelfile, ...) {
params <- list(...) params <- list(...)
private$params <- lgb.params2str(params) private$params <- lgb.params2str(params)
# Create new lgb handle # Create new lgb handle
handle <- 0.0 handle <- 0.0
# Check if handle is a character # Check if handle is a character
if (is.character(modelfile)) { if (is.character(modelfile)) {
# Create handle on it # Create handle on it
handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret = handle, lgb.c_str(modelfile)) handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret = handle, lgb.c_str(modelfile))
private$need_free_handle <- TRUE private$need_free_handle <- TRUE
} else if (is(modelfile, "lgb.Booster.handle")) { } else if (methods::is(modelfile, "lgb.Booster.handle")) {
# Check if model file is a booster handle already # Check if model file is a booster handle already
handle <- modelfile handle <- modelfile
private$need_free_handle <- FALSE private$need_free_handle <- FALSE
} else { } else {
# Model file is unknown # Model file is unknown
stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle") stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")
} }
# Override class and store it # Override class and store it
class(handle) <- "lgb.Booster.handle" class(handle) <- "lgb.Booster.handle"
private$handle <- handle private$handle <- handle
}, },
# Get current iteration # Get current iteration
current_iter = function() { current_iter = function() {
cur_iter <- 0L cur_iter <- 0L
lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle) lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)
}, },
# Predict from data # Predict from data
predict = function(data, predict = function(data,
num_iteration = NULL, num_iteration = NULL,
...@@ -66,22 +68,22 @@ Predictor <- R6Class( ...@@ -66,22 +68,22 @@ Predictor <- R6Class(
predcontrib = FALSE, predcontrib = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE) { reshape = FALSE) {
# Check if number of iterations is existing - if not, then set it to -1 (use all) # Check if number of iterations is existing - if not, then set it to -1 (use all)
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
num_iteration <- -1 num_iteration <- -1
} }
# Set temporary variable # Set temporary variable
num_row <- 0L num_row <- 0L
# Check if data is a file name # Check if data is a file name
if (is.character(data)) { if (is.character(data)) {
# Data is a filename, create a temporary file with a "lightgbm_" pattern in it # Data is a filename, create a temporary file with a "lightgbm_" pattern in it
tmp_filename <- tempfile(pattern = "lightgbm_") tmp_filename <- tempfile(pattern = "lightgbm_")
on.exit(unlink(tmp_filename), add = TRUE) on.exit(unlink(tmp_filename), add = TRUE)
# Predict from temporary file # Predict from temporary file
lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data, lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data,
as.integer(header), as.integer(header),
...@@ -91,19 +93,19 @@ Predictor <- R6Class( ...@@ -91,19 +93,19 @@ Predictor <- R6Class(
as.integer(num_iteration), as.integer(num_iteration),
private$params, private$params,
lgb.c_str(tmp_filename)) lgb.c_str(tmp_filename))
# Get predictions from file # Get predictions from file
preds <- read.delim(tmp_filename, header = FALSE, seq = "\t") preds <- read.delim(tmp_filename, header = FALSE, seq = "\t")
num_row <- nrow(preds) num_row <- nrow(preds)
preds <- as.vector(t(preds)) preds <- as.vector(t(preds))
} else { } else {
# Not a file, we need to predict from R object # Not a file, we need to predict from R object
num_row <- nrow(data) num_row <- nrow(data)
npred <- 0L npred <- 0L
# Check number of predictions to do # Check number of predictions to do
npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", npred <- lgb.call("LGBM_BoosterCalcNumPredict_R",
ret = npred, ret = npred,
...@@ -113,10 +115,10 @@ Predictor <- R6Class( ...@@ -113,10 +115,10 @@ Predictor <- R6Class(
as.integer(predleaf), as.integer(predleaf),
as.integer(predcontrib), as.integer(predcontrib),
as.integer(num_iteration)) as.integer(num_iteration))
# Pre-allocate empty vector # Pre-allocate empty vector
preds <- numeric(npred) preds <- numeric(npred)
# Check if data is a matrix # Check if data is a matrix
if (is.matrix(data)) { if (is.matrix(data)) {
preds <- lgb.call("LGBM_BoosterPredictForMat_R", preds <- lgb.call("LGBM_BoosterPredictForMat_R",
...@@ -130,8 +132,8 @@ Predictor <- R6Class( ...@@ -130,8 +132,8 @@ Predictor <- R6Class(
as.integer(predcontrib), as.integer(predcontrib),
as.integer(num_iteration), as.integer(num_iteration),
private$params) private$params)
} else if (is(data, "dgCMatrix")) { } else if (methods::is(data, "dgCMatrix")) {
if (length(data@p) > 2147483647) { if (length(data@p) > 2147483647) {
stop("Cannot support large CSC matrix") stop("Cannot support large CSC matrix")
} }
...@@ -150,44 +152,44 @@ Predictor <- R6Class( ...@@ -150,44 +152,44 @@ Predictor <- R6Class(
as.integer(predcontrib), as.integer(predcontrib),
as.integer(num_iteration), as.integer(num_iteration),
private$params) private$params)
} else { } else {
# Cannot predict on unknown class # Cannot predict on unknown class
# to-do: predict from lgb.Dataset # to-do: predict from lgb.Dataset
stop("predict: cannot predict on data of class ", sQuote(class(data))) stop("predict: cannot predict on data of class ", sQuote(class(data)))
} }
} }
# Check if number of rows is strange (not a multiple of the dataset rows) # Check if number of rows is strange (not a multiple of the dataset rows)
if (length(preds) %% num_row != 0) { if (length(preds) %% num_row != 0) {
stop("predict: prediction length ", sQuote(length(preds))," is not a multiple of nrows(data): ", sQuote(num_row)) stop("predict: prediction length ", sQuote(length(preds))," is not a multiple of nrows(data): ", sQuote(num_row))
} }
# Get number of cases per row # Get number of cases per row
npred_per_case <- length(preds) / num_row npred_per_case <- length(preds) / num_row
# Data reshaping # Data reshaping
if (predleaf | predcontrib) { if (predleaf | predcontrib) {
# Predict leaves only, reshaping is mandatory # Predict leaves only, reshaping is mandatory
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE) preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
} else if (reshape && npred_per_case > 1) { } else if (reshape && npred_per_case > 1) {
# Predict with data reshaping # Predict with data reshaping
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE) preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
} }
# Return predictions # Return predictions
return(preds) return(preds)
} }
), ),
private = list(handle = NULL, private = list(handle = NULL,
need_free_handle = FALSE, need_free_handle = FALSE,
......
#' Parse a LightGBM model json dump #' Parse a LightGBM model json dump
#' #'
#' Parse a LightGBM model json dump into a \code{data.table} structure. #' Parse a LightGBM model json dump into a \code{data.table} structure.
#' #'
#' @param model object of class \code{lgb.Booster} #' @param model object of class \code{lgb.Booster}
#' @param num_iteration number of iterations you want to predict with. NULL or #' @param num_iteration number of iterations you want to predict with. NULL or
#' <= 0 means use best iteration #' <= 0 means use best iteration
#' #'
#' @return #' @return
#' A \code{data.table} with detailed information about model trees' nodes and leafs. #' A \code{data.table} with detailed information about model trees' nodes and leafs.
#' #'
#' The columns of the \code{data.table} are: #' The columns of the \code{data.table} are:
#' #'
#' \itemize{ #' \itemize{
#' \item \code{tree_index}: ID of a tree in a model (integer) #' \item \code{tree_index}: ID of a tree in a model (integer)
#' \item \code{split_index}: ID of a node in a tree (integer) #' \item \code{split_index}: ID of a node in a tree (integer)
...@@ -28,11 +28,11 @@ ...@@ -28,11 +28,11 @@
#' \item \code{leaf_value}: Leaf value #' \item \code{leaf_value}: Leaf value
#' \item \code{leaf_count}: The number of observation collected by a leaf #' \item \code{leaf_count}: The number of observation collected by a leaf
#' } #' }
#' #'
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' #'
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train #' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label) #' dtrain <- lgb.Dataset(train$data, label = train$label)
...@@ -45,43 +45,45 @@ ...@@ -45,43 +45,45 @@
#' #'
#' tree_dt <- lgb.model.dt.tree(model) #' tree_dt <- lgb.model.dt.tree(model)
#' } #' }
#' #'
#' @importFrom magrittr %>%
#' @importFrom data.table := data.table
#' @importFrom jsonlite fromJSON
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {
  # Dump the booster to its JSON text representation first
  json_model <- lgb.dump(model, num_iteration = num_iteration)
  # Parse the JSON without data.frame/matrix simplification so the nested
  # per-tree structure survives as plain R lists
  parsed_json_model <- jsonlite::fromJSON(json_model,
                                          simplifyVector = TRUE,
                                          simplifyDataFrame = FALSE,
                                          simplifyMatrix = FALSE,
                                          flatten = FALSE)
  # Turn each tree of the model into its own data.table
  tree_list <- lapply(parsed_json_model$tree_info, single.tree.parse)
  # Stack the per-tree tables into a single data.table
  tree_dt <- data.table::rbindlist(l = tree_list, use.names = TRUE)
  # Translate numeric split-feature indices into feature names
  tree_dt[, split_feature := Lookup(split_feature,
                                    seq.int(from = 0, to = parsed_json_model$max_feature_idx),
                                    parsed_json_model$feature_names)]
  # Hand back the combined node/leaf table
  return(tree_dt)
}
#' @importFrom data.table data.table rbindlist
single.tree.parse <- function(lgb_tree) { single.tree.parse <- function(lgb_tree) {
# Traverse tree function # Traverse tree function
pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) { pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) {
if (is.null(env)) { if (is.null(env)) {
# Setup initial default data.table with default types # Setup initial default data.table with default types
env <- new.env(parent = emptyenv()) env <- new.env(parent = emptyenv())
...@@ -103,10 +105,10 @@ single.tree.parse <- function(lgb_tree) { ...@@ -103,10 +105,10 @@ single.tree.parse <- function(lgb_tree) {
# start tree traversal # start tree traversal
pre_order_traversal(env, tree_node_leaf, current_depth, parent_index) pre_order_traversal(env, tree_node_leaf, current_depth, parent_index)
} else { } else {
# Check if split index is not null in leaf # Check if split index is not null in leaf
if (!is.null(tree_node_leaf$split_index)) { if (!is.null(tree_node_leaf$split_index)) {
# update data.table # update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt, env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("split_index", c(tree_node_leaf[c("split_index",
...@@ -121,7 +123,7 @@ single.tree.parse <- function(lgb_tree) { ...@@ -121,7 +123,7 @@ single.tree.parse <- function(lgb_tree) {
"node_parent" = parent_index)), "node_parent" = parent_index)),
use.names = TRUE, use.names = TRUE,
fill = TRUE) fill = TRUE)
# Traverse tree again both left and right # Traverse tree again both left and right
pre_order_traversal(env, pre_order_traversal(env,
tree_node_leaf$left_child, tree_node_leaf$left_child,
...@@ -131,9 +133,9 @@ single.tree.parse <- function(lgb_tree) { ...@@ -131,9 +133,9 @@ single.tree.parse <- function(lgb_tree) {
tree_node_leaf$right_child, tree_node_leaf$right_child,
current_depth = current_depth + 1L, current_depth = current_depth + 1L,
parent_index = tree_node_leaf$split_index) parent_index = tree_node_leaf$split_index)
} else if (!is.null(tree_node_leaf$leaf_index)) { } else if (!is.null(tree_node_leaf$leaf_index)) {
# update data.table # update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt, env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("leaf_index", c(tree_node_leaf[c("leaf_index",
...@@ -143,29 +145,30 @@ single.tree.parse <- function(lgb_tree) { ...@@ -143,29 +145,30 @@ single.tree.parse <- function(lgb_tree) {
"leaf_parent" = parent_index)), "leaf_parent" = parent_index)),
use.names = TRUE, use.names = TRUE,
fill = TRUE) fill = TRUE)
} }
} }
return(env$single_tree_dt) return(env$single_tree_dt)
} }
# Traverse structure # Traverse structure
single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure) single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
# Store index # Store index
single_tree_dt[, tree_index := lgb_tree$tree_index] single_tree_dt[, tree_index := lgb_tree$tree_index]
# Return tree # Return tree
return(single_tree_dt) return(single_tree_dt)
} }
# Vectorised lookup table: for every element of `key`, find its position in
# `key_lookup` and return the value at that position in `value_lookup`.
# Elements of `key` with no match (or matches that map to NA) are replaced
# by `missing`.
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {
  # Position of each key among the lookup keys; NA where the key is absent
  idx <- match(key, key_lookup)
  # Fetch the corresponding values; unmatched keys produce NA entries here
  out <- value_lookup[idx]
  # Substitute the fallback for every NA entry
  out[is.na(out)] <- missing
  out
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment