"src/vscode:/vscode.git/clone" did not exist on "0f3d90e7b0afd39733a5c8aefe772425aad764cc"
Commit 21575cb2 authored by James Lamb, committed by Laurae
Browse files

updated R dependencies (#1619)

parent b2240721
......@@ -23,23 +23,23 @@ URL: https://github.com/Microsoft/LightGBM
BugReports: https://github.com/Microsoft/LightGBM/issues
VignetteBuilder: knitr
Suggests:
Ckmeans.1d.dp (>= 3.3.1),
DiagrammeR (>= 0.8.1),
ggplot2 (>= 1.0.1),
igraph (>= 1.0.1),
knitr,
rmarkdown,
ggplot2 (>= 1.0.1),
DiagrammeR (>= 0.8.1),
Ckmeans.1d.dp (>= 3.3.1),
vcd (>= 1.3),
stringi (>= 0.5.2),
testthat,
igraph (>= 1.0.1),
stringi (>= 0.5.2)
vcd (>= 1.3)
Depends:
R (>= 3.0),
R6 (>= 2.0)
Imports:
graphics,
methods,
Matrix (>= 1.1-0),
data.table (>= 1.9.6),
graphics,
jsonlite (>= 1.0),
magrittr (>= 1.5),
jsonlite (>= 1.0)
Matrix (>= 1.1-0),
methods
RoxygenNote: 6.0.1
......@@ -38,9 +38,15 @@ export(slice)
import(methods)
importFrom(R6,R6Class)
importFrom(data.table,":=")
importFrom(data.table,data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,set)
importFrom(graphics,barplot)
importFrom(graphics,par)
importFrom(jsonlite,fromJSON)
importFrom(magrittr,"%>%")
importFrom(magrittr,"%T>%")
importFrom(magrittr,extract)
importFrom(magrittr,inset)
importFrom(methods,is)
useDynLib(lib_lightgbm)
This diff is collapsed.
#' @importFrom methods is
Predictor <- R6Class(
classname = "lgb.Predictor",
cloneable = FALSE,
public = list(
# finalize(): destructor hook for the R6 object. Releases the native
# booster handle, but only when this predictor created (and therefore
# owns) the handle itself.
finalize = function() {
  # Only free a handle we own and that is still live
  owns_live_handle <- private$need_free_handle && !lgb.is.null.handle(private$handle)
  if (owns_live_handle) {
    # Ask the native library to free the booster, then drop our reference
    lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
    private$handle <- NULL
  }
},
# Initialize will create a starter model
initialize = function(modelfile, ...) {
params <- list(...)
private$params <- lgb.params2str(params)
# Create new lgb handle
handle <- 0.0
# Check if handle is a character
if (is.character(modelfile)) {
# Create handle on it
handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret = handle, lgb.c_str(modelfile))
private$need_free_handle <- TRUE
} else if (is(modelfile, "lgb.Booster.handle")) {
} else if (methods::is(modelfile, "lgb.Booster.handle")) {
# Check if model file is a booster handle already
handle <- modelfile
private$need_free_handle <- FALSE
} else {
# Model file is unknown
stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")
}
# Override class and store it
class(handle) <- "lgb.Booster.handle"
private$handle <- handle
},
# current_iter(): number of boosting iterations in the model,
# as reported by the native library for this booster handle.
current_iter = function() {
  iteration_count <- 0L
  lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = iteration_count, private$handle)
},
# Predict from data
predict = function(data,
num_iteration = NULL,
......@@ -66,22 +68,22 @@ Predictor <- R6Class(
predcontrib = FALSE,
header = FALSE,
reshape = FALSE) {
# Check if number of iterations is existing - if not, then set it to -1 (use all)
if (is.null(num_iteration)) {
num_iteration <- -1
}
# Set temporary variable
num_row <- 0L
# Check if data is a file name
if (is.character(data)) {
# Data is a filename, create a temporary file with a "lightgbm_" pattern in it
tmp_filename <- tempfile(pattern = "lightgbm_")
on.exit(unlink(tmp_filename), add = TRUE)
# Predict from temporary file
lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data,
as.integer(header),
......@@ -91,19 +93,19 @@ Predictor <- R6Class(
as.integer(num_iteration),
private$params,
lgb.c_str(tmp_filename))
# Get predictions from file
preds <- read.delim(tmp_filename, header = FALSE, seq = "\t")
num_row <- nrow(preds)
preds <- as.vector(t(preds))
} else {
# Not a file, we need to predict from R object
num_row <- nrow(data)
npred <- 0L
# Check number of predictions to do
npred <- lgb.call("LGBM_BoosterCalcNumPredict_R",
ret = npred,
......@@ -113,10 +115,10 @@ Predictor <- R6Class(
as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration))
# Pre-allocate empty vector
preds <- numeric(npred)
# Check if data is a matrix
if (is.matrix(data)) {
preds <- lgb.call("LGBM_BoosterPredictForMat_R",
......@@ -130,8 +132,8 @@ Predictor <- R6Class(
as.integer(predcontrib),
as.integer(num_iteration),
private$params)
} else if (is(data, "dgCMatrix")) {
} else if (methods::is(data, "dgCMatrix")) {
if (length(data@p) > 2147483647) {
stop("Cannot support large CSC matrix")
}
......@@ -150,44 +152,44 @@ Predictor <- R6Class(
as.integer(predcontrib),
as.integer(num_iteration),
private$params)
} else {
# Cannot predict on unknown class
# to-do: predict from lgb.Dataset
stop("predict: cannot predict on data of class ", sQuote(class(data)))
}
}
# Check if number of rows is strange (not a multiple of the dataset rows)
if (length(preds) %% num_row != 0) {
stop("predict: prediction length ", sQuote(length(preds))," is not a multiple of nrows(data): ", sQuote(num_row))
}
# Get number of cases per row
npred_per_case <- length(preds) / num_row
# Data reshaping
if (predleaf | predcontrib) {
# Predict leaves only, reshaping is mandatory
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
} else if (reshape && npred_per_case > 1) {
# Predict with data reshaping
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
}
# Return predictions
return(preds)
}
),
private = list(handle = NULL,
need_free_handle = FALSE,
......
#' Parse a LightGBM model json dump
#'
#'
#' Parse a LightGBM model json dump into a \code{data.table} structure.
#'
#'
#' @param model object of class \code{lgb.Booster}
#' @param num_iteration number of iterations you want to predict with. NULL or
#' @param num_iteration number of iterations you want to predict with. NULL or
#' <= 0 means use best iteration
#'
#'
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leaves.
#'
#'
#' The columns of the \code{data.table} are:
#'
#'
#' \itemize{
#' \item \code{tree_index}: ID of a tree in a model (integer)
#' \item \code{split_index}: ID of a node in a tree (integer)
......@@ -28,11 +28,11 @@
#' \item \code{leaf_value}: Leaf value
#' \item \code{leaf_count}: The number of observation collected by a leaf
#' }
#'
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#'
#'
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
......@@ -45,43 +45,45 @@
#'
#' tree_dt <- lgb.model.dt.tree(model)
#' }
#'
#'
#' @importFrom magrittr %>%
#' @importFrom data.table :=
#' @importFrom data.table := data.table
#' @importFrom jsonlite fromJSON
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {
  # Serialize the booster into its JSON text representation
  model_json <- lgb.dump(model, num_iteration = num_iteration)
  # Deserialize: keep lists as lists, only atomic vectors are simplified
  model_parsed <- jsonlite::fromJSON(
    model_json,
    simplifyVector = TRUE,
    simplifyDataFrame = FALSE,
    simplifyMatrix = FALSE,
    flatten = FALSE
  )
  # Parse each tree into a data.table, then stack them into one table
  per_tree <- lapply(model_parsed$tree_info, single.tree.parse)
  tree_dt <- data.table::rbindlist(per_tree, use.names = TRUE)
  # Translate 0-based feature indices into human-readable feature names
  feature_ids <- seq.int(from = 0, to = model_parsed$max_feature_idx)
  tree_dt[, split_feature := Lookup(split_feature, feature_ids, model_parsed$feature_names)]
  # Hand back the combined per-node/per-leaf table
  return(tree_dt)
}
#' @importFrom data.table data.table rbindlist
single.tree.parse <- function(lgb_tree) {
# Traverse tree function
pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) {
if (is.null(env)) {
# Setup initial default data.table with default types
env <- new.env(parent = emptyenv())
......@@ -103,10 +105,10 @@ single.tree.parse <- function(lgb_tree) {
# start tree traversal
pre_order_traversal(env, tree_node_leaf, current_depth, parent_index)
} else {
# Check if split index is not null in leaf
if (!is.null(tree_node_leaf$split_index)) {
# update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("split_index",
......@@ -121,7 +123,7 @@ single.tree.parse <- function(lgb_tree) {
"node_parent" = parent_index)),
use.names = TRUE,
fill = TRUE)
# Traverse tree again both left and right
pre_order_traversal(env,
tree_node_leaf$left_child,
......@@ -131,9 +133,9 @@ single.tree.parse <- function(lgb_tree) {
tree_node_leaf$right_child,
current_depth = current_depth + 1L,
parent_index = tree_node_leaf$split_index)
} else if (!is.null(tree_node_leaf$leaf_index)) {
# update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("leaf_index",
......@@ -143,29 +145,30 @@ single.tree.parse <- function(lgb_tree) {
"leaf_parent" = parent_index)),
use.names = TRUE,
fill = TRUE)
}
}
return(env$single_tree_dt)
}
# Traverse structure
single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
# Store index
single_tree_dt[, tree_index := lgb_tree$tree_index]
# Return tree
return(single_tree_dt)
}
# Vectorised lookup: translate each element of `key` into the value whose
# position in `value_lookup` matches the key's position in `key_lookup`.
#
# key          : vector of keys to translate
# key_lookup   : vector of known keys, aligned element-wise with `value_lookup`
# value_lookup : vector of replacement values
# missing      : fallback used when a key has no match (default NA)
#
# Returns a vector the same length as `key`. Base R subsetting replaces the
# previous magrittr `extract`/`inset` pipeline (same behavior, no extra
# dependency, and no oddly-spaced `(. ,` call).
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {
  # match() gives the index of each key in key_lookup, or NA when absent
  result <- value_lookup[match(key, key_lookup)]
  # Unmatched keys (and NA lookup values, as before) fall back to `missing`
  result[is.na(result)] <- missing
  result
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment