lgb.model.dt.tree.R 7.58 KB
Newer Older
1
#' Parse a LightGBM model json dump
2
#' 
3
#' Parse a LightGBM model json dump into a \code{data.table} structure.
4
#' 
5
#' @param model object of class \code{lgb.Booster}
James Lamb's avatar
James Lamb committed
6
7
#' @param num_iteration number of iterations you want to predict with. NULL or 
#'                      <= 0 means use best iteration
8
#' 
9
10
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leaves.
11
#' 
12
#' The columns of the \code{data.table} are:
13
#' 
14
15
16
17
#' \itemize{
#'  \item \code{tree_index}: ID of a tree in a model (integer)
#'  \item \code{depth}: Depth of a node or leaf in its tree, the root being 0 (integer)
#'  \item \code{split_index}: ID of a node in a tree (integer)
#'  \item \code{split_feature}: for a node, it's a feature name (character);
18
#'                              for a leaf, it is simply \code{NA}
19
20
21
22
23
#'  \item \code{node_parent}: ID of the parent node for current node (integer)
#'  \item \code{leaf_index}: ID of a leaf in a tree (integer)
#'  \item \code{leaf_parent}: ID of the parent node for current leaf (integer)
#'  \item \code{split_gain}: Split gain of a node
#'  \item \code{threshold}: Splitting threshold value of a node
24
#'  \item \code{decision_type}: Decision type of a node
25
#'  \item \code{default_left}: Determine how to handle NA value, TRUE -> Left, FALSE -> Right
26
27
28
29
30
#'  \item \code{internal_value}: Node value
#'  \item \code{internal_count}: The number of observations collected by a node
#'  \item \code{leaf_value}: Leaf value
#'  \item \code{leaf_count}: The number of observations collected by a leaf
#' }
31
#' 
32
#' @examples
33
34
35
36
#' \dontrun{
#' library(lightgbm)
#' 
#' data(agaricus.train, package = "lightgbm")
37
38
39
40
41
42
43
44
45
46
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#'
#' params <- list(objective = "binary",
#'                learning_rate = 0.01, num_leaves = 63, max_depth = -1,
#'                min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1)
#' model <- lgb.train(params, dtrain, 20)
#'
#' tree_dt <- lgb.model.dt.tree(model)
47
48
#' }
#' 
49
50
51
52
#' @importFrom magrittr %>%
#' @importFrom data.table :=
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {

  # Step 1: serialise the booster to its JSON text form
  dumped_json <- lgb.dump(model, num_iteration = num_iteration)

  # Step 2: read the JSON text into nested R lists; atomic arrays become
  # vectors, but nothing is collapsed into data frames or matrices
  model_info <- jsonlite::fromJSON(
    dumped_json,
    simplifyVector = TRUE,
    simplifyDataFrame = FALSE,
    simplifyMatrix = FALSE,
    flatten = FALSE
  )

  # Step 3: flatten every tree into its own table, then stack the tables,
  # matching columns by name
  per_tree_tables <- lapply(model_info$tree_info, single.tree.parse)
  tree_dt <- data.table::rbindlist(per_tree_tables, use.names = TRUE)

  # Step 4: split_feature holds zero-based feature indices at this point;
  # swap them for the feature names stored in the dump
  feature_ids <- seq.int(from = 0, to = model_info$max_feature_idx)
  tree_dt[, split_feature := Lookup(split_feature,
                                    feature_ids,
                                    model_info$feature_names)]

  tree_dt

}

# Flatten one tree of a parsed LightGBM JSON dump into a data.table.
#
# `lgb_tree` is a list with (at least) the elements `tree_structure`, a nested
# list of split nodes and leaves, and `tree_index`. The result has one row per
# split node or leaf, in pre-order (a node, then its left subtree, then its
# right subtree). Columns that do not apply to a row are left as NA, and
# `tree_index` is filled with `lgb_tree$tree_index` on every row.
single.tree.parse <- function(lgb_tree) {

  # Fields copied verbatim from a split node / from a leaf
  node_fields <- c("split_index",
                   "split_feature",
                   "split_gain",
                   "threshold",
                   "decision_type",
                   "default_left",
                   "internal_value",
                   "internal_count")
  leaf_fields <- c("leaf_index",
                   "leaf_value",
                   "leaf_count")

  # Zero-row template declaring the result columns and their order; rbindlist
  # may still promote a column's type when the appended rows need a wider one
  template_dt <- data.table::data.table(tree_index = integer(0),
                                        depth = integer(0),
                                        split_index = integer(0),
                                        split_feature = integer(0),
                                        node_parent = integer(0),
                                        leaf_index = integer(0),
                                        leaf_parent = integer(0),
                                        split_gain = numeric(0),
                                        threshold = numeric(0),
                                        decision_type = character(0),
                                        default_left = character(0),
                                        internal_value = integer(0),
                                        internal_count = integer(0),
                                        leaf_value = integer(0),
                                        leaf_count = integer(0))

  # Rows are accumulated in a shared list and bound once after the traversal.
  # Binding inside the recursion would re-copy the whole table for every node,
  # making the parse quadratic in the size of the tree.
  collected <- new.env(parent = emptyenv())
  collected$rows <- list(template_dt)

  add_row <- function(row) {
    collected$rows[[length(collected$rows) + 1L]] <- row
    invisible(NULL)
  }

  # Visit a node (or leaf), record it, then descend left before right
  pre_order_traversal <- function(tree_node_leaf,
                                  current_depth = 0L,
                                  parent_index = NA_integer_) {

    if (!is.null(tree_node_leaf$split_index)) {

      # Split node: record it, then recurse into both children, which get
      # this node's split_index as their parent
      add_row(c(tree_node_leaf[node_fields],
                "depth" = current_depth,
                "node_parent" = parent_index))

      pre_order_traversal(tree_node_leaf$left_child,
                          current_depth = current_depth + 1L,
                          parent_index = tree_node_leaf$split_index)
      pre_order_traversal(tree_node_leaf$right_child,
                          current_depth = current_depth + 1L,
                          parent_index = tree_node_leaf$split_index)

    } else if (!is.null(tree_node_leaf$leaf_index)) {

      # Leaf: record it; there is nothing further to descend into
      add_row(c(tree_node_leaf[leaf_fields],
                "depth" = current_depth,
                "leaf_parent" = parent_index))

    }

    # Anything with neither a split_index nor a leaf_index is skipped
    invisible(NULL)
  }

  # Traverse structure
  pre_order_traversal(lgb_tree$tree_structure)

  # Bind everything in one pass; fill = TRUE supplies NA for the columns a
  # given row does not carry
  single_tree_dt <- data.table::rbindlist(l = collected$rows,
                                          use.names = TRUE,
                                          fill = TRUE)

  # Store index
  single_tree_dt[, tree_index := lgb_tree$tree_index]

  # Return tree
  return(single_tree_dt)

}

# Translate keys into values through a pair of parallel lookup vectors.
#
# Each element of `key` is located in `key_lookup` and replaced by the element
# of `value_lookup` at the same position. Entries that come out as NA (the key
# was not found, or the looked-up value is itself NA) are set to `missing`.
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {

  # Position of every key inside key_lookup (NA when absent)
  positions <- match(key, key_lookup)

  # Values at those positions; an NA position yields an NA entry
  translated <- value_lookup[positions]

  # Put the placeholder wherever no value was obtained
  translated[is.na(translated)] <- missing

  translated

}