lgb.model.dt.tree.R 7.63 KB
Newer Older
1
#' Parse a LightGBM model json dump
James Lamb's avatar
James Lamb committed
2
#'
3
#' Parse a LightGBM model json dump into a \code{data.table} structure.
James Lamb's avatar
James Lamb committed
4
#'
5
#' @param model object of class \code{lgb.Booster}
James Lamb's avatar
James Lamb committed
6
#' @param num_iteration number of iterations you want to predict with. NULL or
James Lamb's avatar
James Lamb committed
7
#'                      <= 0 means use best iteration
James Lamb's avatar
James Lamb committed
8
#'
9
10
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leaves.
James Lamb's avatar
James Lamb committed
11
#'
12
#' The columns of the \code{data.table} are:
James Lamb's avatar
James Lamb committed
13
#'
14
15
16
17
#' \itemize{
#'  \item \code{tree_index}: ID of a tree in a model (integer)
#'  \item \code{depth}: depth of a node or leaf in its tree, the root being at depth 0 (integer)
#'  \item \code{split_index}: ID of a node in a tree (integer)
#'  \item \code{split_feature}: for a node, it's a feature name (character);
18
#'                              for a leaf, it is \code{NA}
19
20
21
22
#'  \item \code{node_parent}: ID of the parent node for current node (integer)
#'  \item \code{leaf_index}: ID of a leaf in a tree (integer)
#'  \item \code{leaf_parent}: ID of the parent node for current leaf (integer)
#'  \item \code{split_gain}: Split gain of a node
James Lamb's avatar
James Lamb committed
23
#'  \item \code{threshold}: Splitting threshold value of a node
24
#'  \item \code{decision_type}: Decision type of a node
25
#'  \item \code{default_left}: Determine how to handle NA value, TRUE -> Left, FALSE -> Right
26
27
28
29
30
#'  \item \code{internal_value}: Node value
#'  \item \code{internal_count}: The number of observations collected by a node
#'  \item \code{leaf_value}: Leaf value
#'  \item \code{leaf_count}: The number of observations collected by a leaf
#' }
James Lamb's avatar
James Lamb committed
31
#'
32
#' @examples
33
34
#' \dontrun{
#' library(lightgbm)
James Lamb's avatar
James Lamb committed
35
#'
36
#' data(agaricus.train, package = "lightgbm")
37
38
39
40
41
42
43
44
45
46
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#'
#' params <- list(objective = "binary",
#'                learning_rate = 0.01, num_leaves = 63, max_depth = -1,
#'                min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1)
#' model <- lgb.train(params, dtrain, 20)
#'
#' tree_dt <- lgb.model.dt.tree(model)
47
#' }
James Lamb's avatar
James Lamb committed
48
#'
49
#' @importFrom magrittr %>%
James Lamb's avatar
James Lamb committed
50
#' @importFrom data.table := data.table rbindlist
James Lamb's avatar
James Lamb committed
51
#' @importFrom jsonlite fromJSON
52
53
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {

  # Pipeline: dump the booster to JSON text, parse that text into a nested
  # list, flatten every dumped tree into a table of nodes and leaves, stack
  # the per-tree tables, then swap split-feature ids for feature names.

  # 1. JSON text of the model for the requested number of iterations
  dumped_json <- lgb.dump(model, num_iteration = num_iteration)

  # 2. JSON text -> nested R list. Data-frame and matrix simplification are
  #    switched off; the tree parser below walks the result with `$` as a
  #    plain nested list.
  booster_spec <- jsonlite::fromJSON(
    dumped_json,
    simplifyVector = TRUE,
    simplifyDataFrame = FALSE,
    simplifyMatrix = FALSE,
    flatten = FALSE
  )

  # 3. One table per entry of tree_info
  per_tree_tables <- lapply(booster_spec$tree_info, single.tree.parse)

  # 4. Stack the per-tree tables, matching columns by name
  tree_dt <- data.table::rbindlist(per_tree_tables, use.names = TRUE)

  # 5. split_feature holds ids running from 0 to max_feature_idx; replace
  #    each id with the entry of feature_names at the matching position.
  #    Rows without a split feature (leaves) come out as NA.
  feature_ids <- seq.int(from = 0, to = booster_spec$max_feature_idx)
  feature_labels <- booster_spec$feature_names
  tree_dt[, split_feature := Lookup(split_feature, feature_ids, feature_labels)]

  tree_dt
}

James Lamb's avatar
James Lamb committed
81

James Lamb's avatar
James Lamb committed
82
#' @importFrom data.table data.table rbindlist
83
single.tree.parse <- function(lgb_tree) {

  # Flatten one dumped tree into a data.table holding one row per internal
  # node and one row per leaf, in pre-order (node, left subtree, right subtree).
  #
  # lgb_tree: a list from which `tree_structure` (nested list of nodes, each
  #           internal node carrying `left_child` / `right_child`) and
  #           `tree_index` are read.
  # Returns:  a data.table with the columns of the template defined below.
  #           Columns that do not apply to a row (leaf_* on internal nodes,
  #           split_* / internal_* on leaves) are filled with NA.

  node_fields <- c(
    "split_index",
    "split_feature",
    "split_gain",
    "threshold",
    "decision_type",
    "default_left",
    "internal_value",
    "internal_count"
  )
  leaf_fields <- c("leaf_index", "leaf_value", "leaf_count")

  # Zero-row template: fixes the column order and the starting column types
  # of the result, including for a tree that yields no rows at all.
  empty_tree_dt <- data.table::data.table(
    tree_index = integer(0),
    depth = integer(0),
    split_index = integer(0),
    split_feature = integer(0),
    node_parent = integer(0),
    leaf_index = integer(0),
    leaf_parent = integer(0),
    split_gain = numeric(0),
    threshold = numeric(0),
    decision_type = character(0),
    default_left = character(0),
    internal_value = integer(0),
    internal_count = integer(0),
    leaf_value = integer(0),
    leaf_count = integer(0)
  )

  # Pre-order walk returning a list of rows (each row is a named list).
  # Rows are only collected here and bound once afterwards; binding the
  # growing table again for every visited node would copy it each time.
  collect_rows <- function(tree_node_leaf, current_depth, parent_index) {

    # Internal node: has a split_index and two children
    if (!is.null(tree_node_leaf$split_index)) {
      node_row <- c(
        tree_node_leaf[node_fields],
        "depth" = current_depth,
        "node_parent" = parent_index
      )
      return(c(
        list(node_row),
        collect_rows(
          tree_node_leaf$left_child,
          current_depth = current_depth + 1L,
          parent_index = tree_node_leaf$split_index
        ),
        collect_rows(
          tree_node_leaf$right_child,
          current_depth = current_depth + 1L,
          parent_index = tree_node_leaf$split_index
        )
      ))
    }

    # Leaf: has a leaf_index and no children
    if (!is.null(tree_node_leaf$leaf_index)) {
      leaf_row <- c(
        tree_node_leaf[leaf_fields],
        "depth" = current_depth,
        "leaf_parent" = parent_index
      )
      return(list(leaf_row))
    }

    # Neither a split_index nor a leaf_index: contributes no row
    list()
  }

  # The root sits at depth 0 and has no parent
  tree_rows <- collect_rows(
    lgb_tree$tree_structure,
    current_depth = 0L,
    parent_index = NA_integer_
  )

  # Single bind: template first so its column order wins, fill = TRUE so the
  # columns absent from a row become NA
  single_tree_dt <- data.table::rbindlist(
    l = c(list(empty_tree_dt), tree_rows),
    use.names = TRUE,
    fill = TRUE
  )

  # Stamp every row with the index of the tree it came from
  single_tree_dt[, tree_index := lgb_tree$tree_index]

  return(single_tree_dt)
}

James Lamb's avatar
James Lamb committed
167
#' @importFrom magrittr %>% extract inset
168
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {

  # Translate each element of `key` into the element of `value_lookup` that
  # sits at the position where the key is found in `key_lookup`.
  #
  # key:          values to translate
  # key_lookup:   reference keys, parallel to `value_lookup`
  # value_lookup: values returned for the matching reference key
  # missing:      replacement for every NA in the result, i.e. keys without
  #               a match as well as matches whose looked-up value is NA

  # Position of every key in the reference keys (NA when absent)
  positions <- match(key, key_lookup)

  # Values at those positions; an NA position yields an NA value
  translated <- value_lookup[positions]

  # Substitute the caller's placeholder for every NA
  translated[is.na(translated)] <- missing

  translated
}