Commit bc628ee4 authored by Yachen Yan's avatar Yachen Yan Committed by Guolin Ke
Browse files

Parse a LightGBM model json dump (#253)

* Add lgb.Booster parsing function

* Add data.table, magrittr, jsonlite to Imports
parent 38ea6f61
...@@ -10,8 +10,8 @@ Description: LightGBM is a gradient boosting framework that uses tree based lear ...@@ -10,8 +10,8 @@ Description: LightGBM is a gradient boosting framework that uses tree based lear
1.Faster training speed and higher efficiency. 1.Faster training speed and higher efficiency.
2.Lower memory usage. 2.Lower memory usage.
3.Better accuracy. 3.Better accuracy.
4.Parallel learning supported 4.Parallel learning supported.
5. Capable of handling large-scale data 5. Capable of handling large-scale data.
License: The MIT License (MIT) | file LICENSE License: The MIT License (MIT) | file LICENSE
URL: https://github.com/Microsoft/LightGBM URL: https://github.com/Microsoft/LightGBM
BugReports: https://github.com/Microsoft/LightGBM/issues BugReports: https://github.com/Microsoft/LightGBM/issues
...@@ -25,14 +25,14 @@ Suggests: ...@@ -25,14 +25,14 @@ Suggests:
vcd (>= 1.3), vcd (>= 1.3),
testthat, testthat,
igraph (>= 1.0.1), igraph (>= 1.0.1),
methods,
data.table (>= 1.9.6),
magrittr (>= 1.5),
stringi (>= 0.5.2) stringi (>= 0.5.2)
Depends: Depends:
R (>= 3.0), R (>= 3.0),
R6 R6
Imports: Imports:
methods,
Matrix (>= 1.1-0), Matrix (>= 1.1-0),
methods data.table (>= 1.9.6),
RoxygenNote: 5.0.1 magrittr (>= 1.5),
\ No newline at end of file jsonlite
RoxygenNote: 5.0.1
...@@ -18,6 +18,7 @@ export(lgb.cv) ...@@ -18,6 +18,7 @@ export(lgb.cv)
export(lgb.dump) export(lgb.dump)
export(lgb.get.eval.result) export(lgb.get.eval.result)
export(lgb.load) export(lgb.load)
export(lgb.model.dt.tree)
export(lgb.save) export(lgb.save)
export(lgb.train) export(lgb.train)
export(lightgbm) export(lightgbm)
...@@ -25,4 +26,6 @@ export(setinfo) ...@@ -25,4 +26,6 @@ export(setinfo)
export(slice) export(slice)
import(methods) import(methods)
importFrom(R6,R6Class) importFrom(R6,R6Class)
importFrom(data.table,":=")
importFrom(magrittr,"%>%")
useDynLib(lightgbm) useDynLib(lightgbm)
#' Parse a LightGBM model json dump
#'
#' Parse a LightGBM model json dump into a \code{data.table} structure.
#'
#' @param model object of class \code{lgb.Booster}
#'
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leafs.
#'
#' The columns of the \code{data.table} are:
#'
#' \itemize{
#' \item \code{tree_index}: ID of a tree in a model (integer)
#' \item \code{split_index}: ID of a node in a tree (integer)
#' \item \code{split_feature}: for a node, it's a feature name (character);
#' for a leaf, it simply labels it as \code{'NA'}
#' \item \code{node_parent}: ID of the parent node for current node (integer)
#' \item \code{leaf_index}: ID of a leaf in a tree (integer)
#' \item \code{leaf_parent}: ID of the parent node for current leaf (integer)
#' \item \code{split_gain}: Split gain of a node
#' \item \code{threshold}: Spliting threshold value of a node
#' \item \code{decision_type}: Decision type of a node
#' \item \code{internal_value}: Node value
#' \item \code{internal_count}: The number of observation collected by a node
#' \item \code{leaf_value}: Leaf value
#' \item \code{leaf_count}: The number of observation collected by a leaf
#' }
#'
#' @examples
#'
#' data(agaricus.train, package = 'lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#'
#' params = list(objective = "binary",
#' learning_rate = 0.01, num_leaves = 63, max_depth = -1,
#' min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1)
#' model <- lgb.train(params, dtrain, 20)
#' model <- lgb.train(params, dtrain, 20)
#'
#' tree_dt <- lgb.model.dt.tree(model)
#'
#' @importFrom magrittr %>%
#' @importFrom data.table :=
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {
json_model <- lgb.dump(model, num_iteration = num_iteration)
parsed_json_model <- jsonlite::fromJSON(json_model,
simplifyVector = TRUE,
simplifyDataFrame = FALSE,
simplifyMatrix = FALSE,
flatten = FALSE)
tree_list <- lapply(parsed_json_model$tree_info, single.tree.parse)
tree_dt <- data.table::rbindlist(tree_list, use.names = TRUE)
tree_dt[, split_feature := Lookup(split_feature,
seq(0, parsed_json_model$max_feature_idx, by = 1),
parsed_json_model$feature_names)]
return(tree_dt)
}
single.tree.parse <- function(lgb_tree) {
single_tree_dt <- data.table::data.table(tree_index = integer(0),
split_index = integer(0), split_feature = integer(0), node_parent = integer(0),
leaf_index = integer(0), leaf_parent = integer(0),
split_gain = numeric(0), threshold = numeric(0), decision_type = character(0),
internal_value = integer(0), internal_count = integer(0),
leaf_value = integer(0), leaf_count = integer(0))
pre_order_traversal <- function(tree_node_leaf, parent_index = NA) {
if (!is.null(tree_node_leaf$split_index)) {
single_tree_dt <<- data.table::rbindlist(l = list(single_tree_dt,
c(tree_node_leaf[c("split_index", "split_feature",
"split_gain", "threshold", "decision_type",
"internal_value", "internal_count")],
"node_parent" = parent_index)),
use.names = TRUE, fill = TRUE)
pre_order_traversal(tree_node_leaf$left_child, parent_index = tree_node_leaf$split_index)
pre_order_traversal(tree_node_leaf$right_child, parent_index = tree_node_leaf$split_index)
} else if (!is.null(tree_node_leaf$leaf_index)) {
single_tree_dt <<- data.table::rbindlist(l = list(single_tree_dt,
tree_node_leaf[c("leaf_index", "leaf_parent",
"leaf_value", "leaf_count")]),
use.names = TRUE, fill = TRUE)
}
}
pre_order_traversal(lgb_tree$tree_structure)
single_tree_dt[, tree_index := lgb_tree$tree_index]
return(single_tree_dt)
}
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {
match(key, key_lookup) %>%
magrittr::extract(value_lookup, .) %>%
magrittr::inset(. , is.na(.), missing)
}
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/lgb.model.dt.tree.R
\name{lgb.model.dt.tree}
\alias{lgb.model.dt.tree}
\title{Parse a LightGBM model json dump}
\usage{
lgb.model.dt.tree(model, num_iteration = NULL)
}
\arguments{
\item{model}{object of class \code{lgb.Booster}}
}
\value{
A \code{data.table} with detailed information about model trees' nodes and leafs.
The columns of the \code{data.table} are:
\itemize{
\item \code{tree_index}: ID of a tree in a model (integer)
\item \code{split_index}: ID of a node in a tree (integer)
\item \code{split_feature}: for a node, it's a feature name (character);
for a leaf, it simply labels it as \code{'NA'}
\item \code{node_parent}: ID of the parent node for current node (integer)
\item \code{leaf_index}: ID of a leaf in a tree (integer)
\item \code{leaf_parent}: ID of the parent node for current leaf (integer)
\item \code{split_gain}: Split gain of a node
\item \code{threshold}: Spliting threshold value of a node
\item \code{decision_type}: Decision type of a node
\item \code{internal_value}: Node value
\item \code{internal_count}: The number of observation collected by a node
\item \code{leaf_value}: Leaf value
\item \code{leaf_count}: The number of observation collected by a leaf
}
}
\description{
Parse a LightGBM model json dump into a \code{data.table} structure.
}
\examples{
data(agaricus.train, package = 'lightgbm')
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
params = list(objective = "binary",
learning_rate = 0.01, num_leaves = 63, max_depth = -1,
min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1)
model <- lgb.train(params, dtrain, 20)
model <- lgb.train(params, dtrain, 20)
tree_dt <- lgb.model.dt.tree(model)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment