lgb.model.dt.tree.R 7.17 KB
Newer Older
1
#' Parse a LightGBM model json dump
2
#' 
3
#' Parse a LightGBM model json dump into a \code{data.table} structure.
4
#' 
5
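#' @param model object of class \code{lgb.Booster}
#' @param num_iteration passed on to \code{lgb.dump} when the model is dumped to json; defaults to \code{NULL}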
6
#' 
7
8
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leaves.
9
#' 
10
#' The columns of the \code{data.table} are:
11
#' 
12
13
14
15
#' \itemize{
#'  \item \code{tree_index}: ID of a tree in a model (integer)
#'  \item \code{depth}: depth of a node or leaf in its tree, the root being at depth 0 (integer)
#'  \item \code{split_index}: ID of a node in a tree (integer)
#'  \item \code{split_feature}: for a node, it's a feature name (character);
16
#'                              for a leaf, it is a missing value (\code{NA})
17
18
19
20
21
#'  \item \code{node_parent}: ID of the parent node for current node (integer)
#'  \item \code{leaf_index}: ID of a leaf in a tree (integer)
#'  \item \code{leaf_parent}: ID of the parent node for current leaf (integer)
#'  \item \code{split_gain}: Split gain of a node
#'  \item \code{threshold}: Splitting threshold value of a node
22
#'  \item \code{decision_type}: Decision type of a node
23
24
25
26
27
#'  \item \code{internal_value}: Node value
#'  \item \code{internal_count}: The number of observations collected by a node
#'  \item \code{leaf_value}: Leaf value
#'  \item \code{leaf_count}: The number of observations collected by a leaf
#' }
28
#' 
29
#' @examples
30
31
32
33
#' \dontrun{
#' library(lightgbm)
#' 
#' data(agaricus.train, package = "lightgbm")
34
35
36
37
38
39
40
41
42
43
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#'
#' params <- list(objective = "binary",
#'                learning_rate = 0.01, num_leaves = 63, max_depth = -1,
#'                min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1)
#' model <- lgb.train(params, dtrain, 20)
#'
#' tree_dt <- lgb.model.dt.tree(model)
44
45
#' }
#' 
46
47
48
49
#' @importFrom magrittr %>%
#' @importFrom data.table :=
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {

  # Step 1: dump the booster to json text and read it back as nested R lists
  # (vectors are simplified, data frames and matrices are deliberately not)
  model_json <- jsonlite::fromJSON(
    lgb.dump(model, num_iteration = num_iteration),
    simplifyVector = TRUE,
    simplifyDataFrame = FALSE,
    simplifyMatrix = FALSE,
    flatten = FALSE
  )

  # Step 2: flatten every tree into its own table, then stack the tables,
  # matching columns by name
  per_tree_tables <- lapply(model_json$tree_info, single.tree.parse)
  tree_dt <- data.table::rbindlist(per_tree_tables, use.names = TRUE)

  # Step 3: the dump refers to features by index, counted from 0 up to
  # max_feature_idx; swap those indices for the feature names
  feature_positions <- seq(0, model_json$max_feature_idx, by = 1)
  feature_labels <- model_json$feature_names
  tree_dt[, split_feature := Lookup(split_feature,
                                    feature_positions,
                                    feature_labels)]

  # Hand back the combined table
  return(tree_dt)

}

# Flatten one parsed tree (one element of the model dump's `tree_info`) into a
# data.table with one row per split node and one row per leaf.
#
# lgb_tree: list holding a `tree_structure` element (nested nodes linked
#   through `left_child` / `right_child`) and a `tree_index` element.
#
# Returns a data.table with the columns declared in the template below, rows
# in pre-order (node, then left subtree, then right subtree). Fields that do
# not apply to a row (e.g. `leaf_value` on a split node) are filled with NA.
single.tree.parse <- function(lgb_tree) {

  # Fields copied verbatim from the parsed json for split nodes and for leaves
  node_fields <- c("split_index",
                   "split_feature",
                   "split_gain",
                   "threshold",
                   "decision_type",
                   "internal_value",
                   "internal_count")
  leaf_fields <- c("leaf_index",
                   "leaf_value",
                   "leaf_count")

  # Rows are collected in a list and bound once at the end: binding inside
  # the traversal would copy the whole accumulated table for every node.
  # An environment is used so the recursive calls share one accumulator.
  collector <- new.env(parent = emptyenv())

  # The first element is an empty template that fixes the column order and
  # the default column types of the result, even when no row gets collected
  collector$rows <- list(data.table::data.table(tree_index = integer(0),
                                                depth = integer(0),
                                                split_index = integer(0),
                                                split_feature = integer(0),
                                                node_parent = integer(0),
                                                leaf_index = integer(0),
                                                leaf_parent = integer(0),
                                                split_gain = numeric(0),
                                                threshold = numeric(0),
                                                decision_type = character(0),
                                                internal_value = integer(0),
                                                internal_count = integer(0),
                                                leaf_value = integer(0),
                                                leaf_count = integer(0)))

  # Traverse tree function: records the current node / leaf, then descends
  pre_order_traversal <- function(tree_node_leaf,
                                  current_depth = 0L,
                                  parent_index = NA_integer_) {

    if (!is.null(tree_node_leaf$split_index)) {

      # Split node: record it, its depth and its parent
      collector$rows[[length(collector$rows) + 1L]] <-
        c(tree_node_leaf[node_fields],
          "depth" = current_depth,
          "node_parent" = parent_index)

      # Visit the left subtree first, then the right one; both children
      # are one level deeper and have the current node as parent
      pre_order_traversal(tree_node_leaf$left_child,
                          current_depth = current_depth + 1L,
                          parent_index = tree_node_leaf$split_index)
      pre_order_traversal(tree_node_leaf$right_child,
                          current_depth = current_depth + 1L,
                          parent_index = tree_node_leaf$split_index)

    } else if (!is.null(tree_node_leaf$leaf_index)) {

      # Leaf: record it, its depth and its parent; nothing to descend into
      collector$rows[[length(collector$rows) + 1L]] <-
        c(tree_node_leaf[leaf_fields],
          "depth" = current_depth,
          "leaf_parent" = parent_index)

    }

    # An element with neither a split index nor a leaf index is skipped
    invisible(NULL)
  }

  # Traverse structure
  pre_order_traversal(lgb_tree$tree_structure)

  # Bind everything in one pass; columns are matched by name and the ones a
  # row does not provide are filled with NA
  single_tree_dt <- data.table::rbindlist(l = collector$rows,
                                          use.names = TRUE,
                                          fill = TRUE)

  # Store index
  single_tree_dt[, tree_index := lgb_tree$tree_index]

  # Return tree
  return(single_tree_dt)

}

# Translate `key` into the entries of `value_lookup` located at the positions
# where `key` matches `key_lookup`.
#
# key:          vector of keys to translate
# key_lookup:   vector of known keys
# value_lookup: vector of values, aligned element-wise with `key_lookup`
# missing:      replacement used wherever the looked-up value is NA (a key
#               absent from `key_lookup`, or an NA entry of `value_lookup`)
#
# Returns the looked-up values, one per element of `key`.
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {

  # Position of each key in the lookup table (NA when the key is unknown)
  matched_position <- match(key, key_lookup)

  # Indexing by an NA position yields NA, which is then swapped for `missing`
  looked_up <- value_lookup[matched_position]
  looked_up[is.na(looked_up)] <- missing

  looked_up

}