Commit 83be8e97 authored by Yachen Yan's avatar Yachen Yan Committed by Guolin Ke
Browse files

Fix R Tree Parse (#864)

Add Depth Column
Refine Code Structure
parent ae6ff288
......@@ -76,62 +76,77 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
single.tree.parse <- function(lgb_tree) {
# Setup initial default data.table with default types
single_tree_dt <- data.table::data.table(tree_index = integer(0),
split_index = integer(0),
split_feature = integer(0),
node_parent = integer(0),
leaf_index = integer(0),
leaf_parent = integer(0),
split_gain = numeric(0),
threshold = numeric(0),
decision_type = character(0),
internal_value = integer(0),
internal_count = integer(0),
leaf_value = integer(0),
leaf_count = integer(0))
# Traverse tree function
pre_order_traversal <- function(tree_node_leaf, parent_index = NA) {
pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) {
# Check if split index is not null in leaf
if (!is.null(tree_node_leaf$split_index)) {
# Overwrite data.table - this should be switched to an envir in the future
single_tree_dt <<- data.table::rbindlist(l = list(single_tree_dt,
c(tree_node_leaf[c("split_index",
"split_feature",
"split_gain",
"threshold",
"decision_type",
"internal_value",
"internal_count")],
"node_parent" = parent_index)),
use.names = TRUE,
fill = TRUE)
# Traverse tree again both left and right
pre_order_traversal(tree_node_leaf$left_child,
parent_index = tree_node_leaf$split_index)
pre_order_traversal(tree_node_leaf$right_child,
parent_index = tree_node_leaf$split_index)
if (is.null(env)) {
# Setup initial default data.table with default types
env <- new.env(parent = emptyenv())
env$single_tree_dt <- data.table::data.table(tree_index = integer(0),
depth = integer(0),
split_index = integer(0),
split_feature = integer(0),
node_parent = integer(0),
leaf_index = integer(0),
leaf_parent = integer(0),
split_gain = numeric(0),
threshold = numeric(0),
decision_type = character(0),
internal_value = integer(0),
internal_count = integer(0),
leaf_value = integer(0),
leaf_count = integer(0))
# start tree traversal
pre_order_traversal(env, tree_node_leaf, current_depth, parent_index)
} else {
} else if (!is.null(tree_node_leaf$leaf_index)) {
# Overwrite data.table - this should be switched to an envir in the future
single_tree_dt <<- data.table::rbindlist(l = list(single_tree_dt,
tree_node_leaf[c("leaf_index",
"leaf_value",
"leaf_count")]),
use.names = TRUE,
fill = TRUE)
# Check if split index is not null in leaf
if (!is.null(tree_node_leaf$split_index)) {
# update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("split_index",
"split_feature",
"split_gain",
"threshold",
"decision_type",
"internal_value",
"internal_count")],
"depth" = current_depth,
"node_parent" = parent_index)),
use.names = TRUE,
fill = TRUE)
# Traverse tree again both left and right
pre_order_traversal(env,
tree_node_leaf$left_child,
current_depth = current_depth + 1L,
parent_index = tree_node_leaf$split_index)
pre_order_traversal(env,
tree_node_leaf$right_child,
current_depth = current_depth + 1L,
parent_index = tree_node_leaf$split_index)
} else if (!is.null(tree_node_leaf$leaf_index)) {
# update data.table
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("leaf_index",
"leaf_value",
"leaf_count")],
"depth" = current_depth,
"leaf_parent" = parent_index)),
use.names = TRUE,
fill = TRUE)
}
}
return(env$single_tree_dt)
}
# Traverse structure
pre_order_traversal(lgb_tree$tree_structure)
single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
# Store index
single_tree_dt[, tree_index := lgb_tree$tree_index]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment