lgb.model.dt.tree.R 7.44 KB
Newer Older
1
#' Parse a LightGBM model json dump
#'
#' Parse a LightGBM model json dump into a \code{data.table} structure.
#'
#' @param model object of class \code{lgb.Booster}
#' @param num_iteration value forwarded unchanged to \code{lgb.dump} as its
#'                      \code{num_iteration} argument (default \code{NULL})
#'
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leaves.
#'
#' The columns of the \code{data.table} are:
#'
#' \itemize{
#'  \item \code{tree_index}: ID of a tree in a model (integer)
#'  \item \code{depth}: Depth of a node or leaf in its tree, the root being 0 (integer)
#'  \item \code{split_index}: ID of a node in a tree (integer)
#'  \item \code{split_feature}: for a node, it's a feature name (character);
#'                              for a leaf, it simply labels it as \code{"NA"}
#'  \item \code{node_parent}: ID of the parent node for current node (integer)
#'  \item \code{leaf_index}: ID of a leaf in a tree (integer)
#'  \item \code{leaf_parent}: ID of the parent node for current leaf (integer)
#'  \item \code{split_gain}: Split gain of a node
#'  \item \code{threshold}: Splitting threshold value of a node
#'  \item \code{decision_type}: Decision type of a node
#'  \item \code{default_left}: Determine how to handle NA value, TRUE -> Left, FALSE -> Right
#'  \item \code{internal_value}: Node value
#'  \item \code{internal_count}: The number of observations collected by a node
#'  \item \code{leaf_value}: Leaf value
#'  \item \code{leaf_count}: The number of observations collected by a leaf
#' }
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#'
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#'
#' params <- list(objective = "binary",
#'                learning_rate = 0.01, num_leaves = 63, max_depth = -1,
#'                min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1)
#' model <- lgb.train(params, dtrain, 20)
#'
#' tree_dt <- lgb.model.dt.tree(model)
#' }
#'
#' @importFrom magrittr %>%
#' @importFrom data.table :=
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {

  # Dump the model to a JSON string first
  json_model <- lgb.dump(model, num_iteration = num_iteration)

  # Parse the JSON; data-frame/matrix simplification is disabled so that the
  # nested tree structure stays a plain list that can be walked recursively
  parsed_json_model <- jsonlite::fromJSON(json_model,
                                          simplifyVector = TRUE,
                                          simplifyDataFrame = FALSE,
                                          simplifyMatrix = FALSE,
                                          flatten = FALSE)

  # Flatten every tree of the dump into its own data.table
  tree_list <- lapply(parsed_json_model$tree_info, single.tree.parse)

  # Stack the per-tree tables into a single data.table
  tree_dt <- data.table::rbindlist(tree_list, use.names = TRUE)

  # Replace the feature indices stored in split_feature by feature names:
  # keys 0..max_feature_idx are mapped positionally onto feature_names,
  # anything unmatched (e.g. leaf rows) becomes NA
  feature_keys <- seq.int(from = 0, to = parsed_json_model$max_feature_idx)
  tree_dt[, split_feature := Lookup(split_feature,
                                    feature_keys,
                                    parsed_json_model$feature_names)]

  # Return tree
  return(tree_dt)

}

# Flatten one parsed tree of a LightGBM JSON dump into a data.table.
#
# lgb_tree is one element of the dump's tree_info list; this function reads
# its tree_structure (nested lists linked by left_child / right_child) and
# its tree_index. The result has one row per split node and per leaf, in
# pre-order (node, then left subtree, then right subtree).
single.tree.parse <- function(lgb_tree) {

  # Zero-row template: fixes the column order of the result and provides the
  # baseline column types. rbindlist() promotes a column when the parsed
  # values need a wider type than the one declared here.
  template_dt <- data.table::data.table(tree_index = integer(0),
                                        depth = integer(0),
                                        split_index = integer(0),
                                        split_feature = integer(0),
                                        node_parent = integer(0),
                                        leaf_index = integer(0),
                                        leaf_parent = integer(0),
                                        split_gain = numeric(0),
                                        threshold = numeric(0),
                                        decision_type = character(0),
                                        default_left = character(0),
                                        internal_value = integer(0),
                                        internal_count = integer(0),
                                        leaf_value = integer(0),
                                        leaf_count = integer(0))

  # Fields copied verbatim from a split node / from a leaf
  node_fields <- c("split_index",
                   "split_feature",
                   "split_gain",
                   "threshold",
                   "decision_type",
                   "default_left",
                   "internal_value",
                   "internal_count")
  leaf_fields <- c("leaf_index",
                   "leaf_value",
                   "leaf_count")

  # Pre-order traversal returning a list with one named list ("row") per
  # visited node or leaf. Rows are only collected here and bound once at the
  # end, instead of re-binding the whole table at every visited node.
  pre_order_traversal <- function(tree_node_leaf,
                                  current_depth = 0L,
                                  parent_index = NA_integer_) {

    if (!is.null(tree_node_leaf[["split_index"]])) {

      # Split node: record it, then descend left and right
      node_row <- c(tree_node_leaf[node_fields],
                    "depth" = current_depth,
                    "node_parent" = parent_index)
      return(c(list(node_row),
               pre_order_traversal(tree_node_leaf[["left_child"]],
                                   current_depth = current_depth + 1L,
                                   parent_index = tree_node_leaf[["split_index"]]),
               pre_order_traversal(tree_node_leaf[["right_child"]],
                                   current_depth = current_depth + 1L,
                                   parent_index = tree_node_leaf[["split_index"]])))

    }

    if (!is.null(tree_node_leaf[["leaf_index"]])) {

      # Leaf: record it, nothing below
      leaf_row <- c(tree_node_leaf[leaf_fields],
                    "depth" = current_depth,
                    "leaf_parent" = parent_index)
      return(list(leaf_row))

    }

    # Neither a split node nor a leaf (e.g. a missing child): no row
    list()
  }

  # Traverse structure
  tree_rows <- pre_order_traversal(lgb_tree$tree_structure)

  # Bind template and rows in one pass; fill = TRUE leaves the columns a row
  # does not provide (leaf columns for nodes and vice versa) as NA
  single_tree_dt <- data.table::rbindlist(l = c(list(template_dt), tree_rows),
                                          use.names = TRUE,
                                          fill = TRUE)

  # Store index
  single_tree_dt[, tree_index := lgb_tree$tree_index]

  # Return tree
  return(single_tree_dt)

}

# Translate keys into values through a pair of parallel lookup vectors.
#
# Each element of `key` is located in `key_lookup`; the element of
# `value_lookup` at the same position is returned. Every NA in the result
# (keys that were not found, or NA entries of `value_lookup` itself) is
# replaced by `missing`.
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {

  # Position of every key inside key_lookup (NA when absent)
  positions <- match(key, key_lookup)

  # Pick the corresponding values, then overwrite the NA results
  looked_up <- value_lookup[positions]
  looked_up[is.na(looked_up)] <- missing

  looked_up

}