Commit 83be8e97 authored by Yachen Yan, committed by Guolin Ke
Browse files

Fix R Tree Parse (#864)

Add Depth Column
Refine Code Structure
parent ae6ff288
...@@ -76,8 +76,14 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { ...@@ -76,8 +76,14 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
single.tree.parse <- function(lgb_tree) { single.tree.parse <- function(lgb_tree) {
# Traverse tree function
pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) {
if (is.null(env)) {
# Setup initial default data.table with default types # Setup initial default data.table with default types
single_tree_dt <- data.table::data.table(tree_index = integer(0), env <- new.env(parent = emptyenv())
env$single_tree_dt <- data.table::data.table(tree_index = integer(0),
depth = integer(0),
split_index = integer(0), split_index = integer(0),
split_feature = integer(0), split_feature = integer(0),
node_parent = integer(0), node_parent = integer(0),
...@@ -90,15 +96,15 @@ single.tree.parse <- function(lgb_tree) { ...@@ -90,15 +96,15 @@ single.tree.parse <- function(lgb_tree) {
internal_count = integer(0), internal_count = integer(0),
leaf_value = integer(0), leaf_value = integer(0),
leaf_count = integer(0)) leaf_count = integer(0))
# start tree traversal
# Traverse tree function pre_order_traversal(env, tree_node_leaf, current_depth, parent_index)
pre_order_traversal <- function(tree_node_leaf, parent_index = NA) { } else {
# Check if split index is not null in leaf # Check if split index is not null in leaf
if (!is.null(tree_node_leaf$split_index)) { if (!is.null(tree_node_leaf$split_index)) {
# Overwrite data.table - this should be switched to an envir in the future # update data.table
single_tree_dt <<- data.table::rbindlist(l = list(single_tree_dt, env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
c(tree_node_leaf[c("split_index", c(tree_node_leaf[c("split_index",
"split_feature", "split_feature",
"split_gain", "split_gain",
...@@ -106,32 +112,41 @@ single.tree.parse <- function(lgb_tree) { ...@@ -106,32 +112,41 @@ single.tree.parse <- function(lgb_tree) {
"decision_type", "decision_type",
"internal_value", "internal_value",
"internal_count")], "internal_count")],
"depth" = current_depth,
"node_parent" = parent_index)), "node_parent" = parent_index)),
use.names = TRUE, use.names = TRUE,
fill = TRUE) fill = TRUE)
# Traverse tree again both left and right # Traverse tree again both left and right
pre_order_traversal(tree_node_leaf$left_child, pre_order_traversal(env,
tree_node_leaf$left_child,
current_depth = current_depth + 1L,
parent_index = tree_node_leaf$split_index) parent_index = tree_node_leaf$split_index)
pre_order_traversal(tree_node_leaf$right_child, pre_order_traversal(env,
tree_node_leaf$right_child,
current_depth = current_depth + 1L,
parent_index = tree_node_leaf$split_index) parent_index = tree_node_leaf$split_index)
} else if (!is.null(tree_node_leaf$leaf_index)) { } else if (!is.null(tree_node_leaf$leaf_index)) {
# Overwrite data.table - this should be switched to an envir in the future # update data.table
single_tree_dt <<- data.table::rbindlist(l = list(single_tree_dt, env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
tree_node_leaf[c("leaf_index", c(tree_node_leaf[c("leaf_index",
"leaf_value", "leaf_value",
"leaf_count")]), "leaf_count")],
"depth" = current_depth,
"leaf_parent" = parent_index)),
use.names = TRUE, use.names = TRUE,
fill = TRUE) fill = TRUE)
} }
} }
return(env$single_tree_dt)
}
# Traverse structure # Traverse structure
pre_order_traversal(lgb_tree$tree_structure) single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
# Store index # Store index
single_tree_dt[, tree_index := lgb_tree$tree_index] single_tree_dt[, tree_index := lgb_tree$tree_index]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment