Unverified Commit 5dfe7168 authored by Michael Mayer's avatar Michael Mayer Committed by GitHub
Browse files

[R-package] Speed-up lgb.importance() (#6364)

parent 628e91a9
...@@ -90,6 +90,16 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { ...@@ -90,6 +90,16 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
#' @importFrom data.table := data.table rbindlist #' @importFrom data.table := data.table rbindlist
.single_tree_parse <- function(lgb_tree) { .single_tree_parse <- function(lgb_tree) {
tree_info_cols <- c(
"split_index"
, "split_feature"
, "split_gain"
, "threshold"
, "decision_type"
, "default_left"
, "internal_value"
, "internal_count"
)
# Traverse tree function # Traverse tree function
pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) { pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) {
...@@ -97,7 +107,8 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { ...@@ -97,7 +107,8 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
if (is.null(env)) { if (is.null(env)) {
# Setup initial default data.table with default types # Setup initial default data.table with default types
env <- new.env(parent = emptyenv()) env <- new.env(parent = emptyenv())
env$single_tree_dt <- data.table::data.table( env$single_tree_dt <- list()
env$single_tree_dt[[1L]] <- data.table::data.table(
tree_index = integer(0L) tree_index = integer(0L)
, depth = integer(0L) , depth = integer(0L)
, split_index = integer(0L) , split_index = integer(0L)
...@@ -127,19 +138,10 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { ...@@ -127,19 +138,10 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
if (!is.null(tree_node_leaf$split_index)) { if (!is.null(tree_node_leaf$split_index)) {
# update data.table # update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt, env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c(
c(tree_node_leaf[c("split_index", tree_node_leaf[tree_info_cols]
"split_feature", , list("depth" = current_depth, "node_parent" = parent_index)
"split_gain", )
"threshold",
"decision_type",
"default_left",
"internal_value",
"internal_count")],
"depth" = current_depth,
"node_parent" = parent_index)),
use.names = TRUE,
fill = TRUE)
# Traverse tree again both left and right # Traverse tree again both left and right
pre_order_traversal( pre_order_traversal(
...@@ -154,31 +156,27 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { ...@@ -154,31 +156,27 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
, current_depth = current_depth + 1L , current_depth = current_depth + 1L
, parent_index = tree_node_leaf$split_index , parent_index = tree_node_leaf$split_index
) )
} else if (!is.null(tree_node_leaf$leaf_index)) { } else if (!is.null(tree_node_leaf$leaf_index)) {
# update data.table # update list
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt, env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c(
c(tree_node_leaf[c("leaf_index", tree_node_leaf[c("leaf_index", "leaf_value", "leaf_count")]
"leaf_value", , list("depth" = current_depth, "leaf_parent" = parent_index)
"leaf_count")], )
"depth" = current_depth,
"leaf_parent" = parent_index)),
use.names = TRUE,
fill = TRUE)
} }
} }
return(env$single_tree_dt) return(env$single_tree_dt)
} }
# Traverse structure # Traverse structure and rowbind everything
single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure) single_tree_dt <- data.table::rbindlist(
pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
, use.names = TRUE
, fill = TRUE
)
# Store index # Store index
single_tree_dt[, tree_index := lgb_tree$tree_index] single_tree_dt[, tree_index := lgb_tree$tree_index]
return(single_tree_dt) return(single_tree_dt)
} }
NROUNDS <- 10L
MAX_DEPTH <- 3L
N <- nrow(iris)
X <- data.matrix(iris[2L:4L])
FEAT <- colnames(X)
NCLASS <- nlevels(iris[, 5L])
model_reg <- lgb.train(
params = list(
objective = "regression"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
)
, data = lgb.Dataset(X, label = iris[, 1L])
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)
model_binary <- lgb.train(
params = list(
objective = "binary"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
)
, data = lgb.Dataset(X, label = iris[, 5L] == "setosa")
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)
model_multiclass <- lgb.train(
params = list(
objective = "multiclass"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
, num_classes = NCLASS
)
, data = lgb.Dataset(X, label = as.integer(iris[, 5L]) - 1L)
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)
model_rank <- lgb.train(
params = list(
objective = "lambdarank"
, num_threads = .LGB_MAX_THREADS
, max.depth = MAX_DEPTH
, lambdarank_truncation_level = 3L
)
, data = lgb.Dataset(
X
, label = as.integer(iris[, 1L] > 5.8)
, group = rep(10L, times = 15L)
)
, verbose = .LGB_VERBOSITY
, nrounds = NROUNDS
)
models <- list(
reg = model_reg
, bin = model_binary
, multi = model_multiclass
, rank = model_rank
)
for (model_name in names(models)) {
model <- models[[model_name]]
expected_n_trees <- NROUNDS
if (model_name == "multi") {
expected_n_trees <- NROUNDS * NCLASS
}
df <- as.data.frame(lgb.model.dt.tree(model))
df_list <- split(df, f = df$tree_index, drop = TRUE)
df_leaf <- df[!is.na(df$leaf_index), ]
df_internal <- df[is.na(df$leaf_index), ]
test_that("lgb.model.dt.tree() returns the right number of trees", {
expect_equal(length(unique(df$tree_index)), expected_n_trees)
})
test_that("num_iteration can return less trees", {
expect_equal(
length(unique(lgb.model.dt.tree(model, num_iteration = 2L)$tree_index))
, 2L * (if (model_name == "multi") NCLASS else 1L)
)
})
test_that("Tree index from lgb.model.dt.tree() is in 0:(NROUNS-1)", {
expect_equal(unique(df$tree_index), (0L:(expected_n_trees - 1L)))
})
test_that("Depth calculated from lgb.model.dt.tree() respects max.depth", {
expect_true(max(df$depth) <= MAX_DEPTH)
})
test_that("Each tree from lgb.model.dt.tree() has single root node", {
expect_equal(
unname(sapply(df_list, function(df) sum(df$depth == 0L)))
, rep(1L, expected_n_trees)
)
})
test_that("Each tree from lgb.model.dt.tree() has two depth 1 nodes", {
expect_equal(
unname(sapply(df_list, function(df) sum(df$depth == 1L)))
, rep(2L, expected_n_trees)
)
})
test_that("leaves from lgb.model.dt.tree() do not have split info", {
internal_node_cols <- c(
"split_index"
, "split_feature"
, "split_gain"
, "threshold"
, "decision_type"
, "default_left"
, "internal_value"
, "internal_count"
)
expect_true(all(is.na(df_leaf[internal_node_cols])))
})
test_that("leaves from lgb.model.dt.tree() have valid leaf info", {
expect_true(all(df_leaf$leaf_index %in% 0L:(2.0^MAX_DEPTH - 1.0)))
expect_true(all(is.finite(df_leaf$leaf_value)))
expect_true(all(df_leaf$leaf_count > 0L & df_leaf$leaf_count <= N))
})
test_that("non-leaves from lgb.model.dt.tree() do not have leaf info", {
leaf_node_cols <- c(
"leaf_index", "leaf_parent", "leaf_value", "leaf_count"
)
expect_true(all(is.na(df_internal[leaf_node_cols])))
})
test_that("non-leaves from lgb.model.dt.tree() have valid split info", {
expect_true(
all(
sapply(
split(df_internal, df_internal$tree_index),
function(x) all(x$split_index %in% 0L:(nrow(x) - 1L))
)
)
)
expect_true(all(df_internal$split_feature %in% FEAT))
num_cols <- c("split_gain", "threshold", "internal_value")
expect_true(all(is.finite(unlist(df_internal[, num_cols]))))
# range of decision type?
expect_true(all(df_internal$default_left %in% c(TRUE, FALSE)))
counts <- df_internal$internal_count
expect_true(all(counts > 1L & counts <= N))
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment