# lgb.model.dt.tree.R
#' Parse a LightGBM model json dump
#'
#' Parse a LightGBM model json dump into a \code{data.table} structure.
#'
#' @param model object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dump; passed
#'                      unchanged to \code{lgb.dump}. With the default,
#'                      \code{NULL}, the choice is left to \code{lgb.dump}.
#'
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leaves.
#'
#' The columns of the \code{data.table} are:
#'
#' \itemize{
#'  \item \code{tree_index}: ID of a tree in a model (integer)
#'  \item \code{split_index}: ID of a node in a tree (integer)
#'  \item \code{split_feature}: for a node, it's a feature name (character);
#'                              for a leaf, it is \code{NA}
#'  \item \code{node_parent}: ID of the parent node for current node (integer)
#'  \item \code{leaf_index}: ID of a leaf in a tree (integer)
#'  \item \code{leaf_parent}: ID of the parent node for current leaf (integer)
#'  \item \code{split_gain}: Split gain of a node
#'  \item \code{threshold}: Splitting threshold value of a node
#'  \item \code{decision_type}: Decision type of a node
#'  \item \code{internal_value}: Node value
#'  \item \code{internal_count}: The number of observations collected by a node
#'  \item \code{leaf_value}: Leaf value
#'  \item \code{leaf_count}: The number of observations collected by a leaf
#' }
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#'
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#'
#' params <- list(objective = "binary",
#'                learning_rate = 0.01, num_leaves = 63, max_depth = -1,
#'                min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1)
#' model <- lgb.train(params, dtrain, 20)
#'
#' tree_dt <- lgb.model.dt.tree(model)
#' }
#'
#' @importFrom magrittr %>%
#' @importFrom data.table :=
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {

  # Dump the model to JSON first
  json_model <- lgb.dump(model, num_iteration = num_iteration)

  # Parse the JSON; data.frame / matrix simplification is switched off so the
  # per-tree structures stay as nested lists that can be walked recursively
  parsed_json_model <- jsonlite::fromJSON(json_model,
                                          simplifyVector = TRUE,
                                          simplifyDataFrame = FALSE,
                                          simplifyMatrix = FALSE,
                                          flatten = FALSE)

  # Flatten every tree into its own data.table
  tree_list <- lapply(parsed_json_model$tree_info, single.tree.parse)

  # Stack the per-tree tables into a single data.table
  tree_dt <- data.table::rbindlist(tree_list, use.names = TRUE)

  # Translate split_feature from the indices 0..max_feature_idx to the
  # corresponding entries of feature_names; rows without a match (leaves)
  # end up as NA
  tree_dt[, split_feature := Lookup(split_feature,
                                    seq(0, parsed_json_model$max_feature_idx, by = 1),
                                    parsed_json_model$feature_names)]

  # Trailing `[]`: a data.table returned from a function right after `:=`
  # would otherwise not auto-print the first time the caller types its name
  tree_dt[]

}

# Flatten one parsed tree into a data.table with one row per split node and
# one row per leaf, in pre-order (node, then left subtree, then right subtree).
#
# lgb_tree: a list with a `tree_index` element and a nested `tree_structure`
#           element, whose nodes carry either `split_index` (plus `left_child`
#           and `right_child`) or `leaf_index`.
#
# Returns a data.table with the columns declared in the template below.
# Columns that do not apply to a row (leaf columns on a split node and vice
# versa) are NA.
single.tree.parse <- function(lgb_tree) {

  # Empty template: fixes the column order and the default column types.
  # The *_value columns are numeric because they hold real-valued outputs.
  template_dt <- data.table::data.table(tree_index = integer(0),
                                        split_index = integer(0),
                                        split_feature = integer(0),
                                        node_parent = integer(0),
                                        leaf_index = integer(0),
                                        leaf_parent = integer(0),
                                        split_gain = numeric(0),
                                        threshold = numeric(0),
                                        decision_type = character(0),
                                        internal_value = numeric(0),
                                        internal_count = integer(0),
                                        leaf_value = numeric(0),
                                        leaf_count = integer(0))

  # Recursively collect one named list per node/leaf. Rows are returned
  # rather than appended to an outer table, so the table is built with a
  # single rbindlist call instead of one copy per node.
  collect_rows <- function(tree_node_leaf, parent_index = NA) {

    if (!is.null(tree_node_leaf$split_index)) {

      # Split node: keep its own fields and remember which node it hangs from
      node_row <- c(tree_node_leaf[c("split_index",
                                     "split_feature",
                                     "split_gain",
                                     "threshold",
                                     "decision_type",
                                     "internal_value",
                                     "internal_count")],
                    "node_parent" = parent_index)

      # Pre-order: this node first, then the left and right subtrees
      c(list(node_row),
        collect_rows(tree_node_leaf$left_child,
                     parent_index = tree_node_leaf$split_index),
        collect_rows(tree_node_leaf$right_child,
                     parent_index = tree_node_leaf$split_index))

    } else if (!is.null(tree_node_leaf$leaf_index)) {

      # Leaf: record the parent split so that leaf_parent is actually filled
      leaf_row <- c(tree_node_leaf[c("leaf_index",
                                     "leaf_value",
                                     "leaf_count")],
                    "leaf_parent" = parent_index)

      list(leaf_row)

    } else {

      # Neither a split nor an indexed leaf: contributes no row
      list()

    }

  }

  # Bind the template and all collected rows in one go; fill = TRUE pads the
  # columns a given row does not provide with NA
  single_tree_dt <- data.table::rbindlist(l = c(list(template_dt),
                                                collect_rows(lgb_tree$tree_structure)),
                                          use.names = TRUE,
                                          fill = TRUE)

  # Stamp every row with the index of the tree it belongs to
  single_tree_dt[, tree_index := lgb_tree$tree_index]

  single_tree_dt

}

# Translate keys into values through two parallel lookup vectors.
#
# key:          vector of keys to translate
# key_lookup:   vector of known keys
# value_lookup: vector of values, position-aligned with key_lookup
# missing:      replacement used wherever the looked-up value is NA, which
#               covers both keys absent from key_lookup and NA entries of
#               value_lookup
#
# Returns a vector the same length as `key`.
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {

  # Position of each key inside key_lookup (NA when the key is not present)
  positions <- match(key, key_lookup)

  # Indexing with an NA position yields NA, so unmatched keys surface as NA
  values <- value_lookup[positions]

  # Substitute the caller-supplied placeholder for every NA result
  values[is.na(values)] <- missing

  values

}