# lgb.model.dt.tree.R
#' @name lgb.model.dt.tree
#' @title Parse a LightGBM model json dump
#' @description Parse a LightGBM model json dump into a \code{data.table} structure.
#' @param model object of class \code{lgb.Booster}.
#' @param num_iteration Number of iterations to include. NULL or <= 0 means use best iteration.
#' @param start_iteration Index (1-based) of the first boosting round to include in the output.
#'        For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#'        means "return information about the fifth, sixth, and seventh trees".
#'
#'        \emph{New in version 4.4.0}
#'
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leaves.
#'
#' The columns of the \code{data.table} are:
#'
#' \itemize{
#'  \item{\code{tree_index}: ID of a tree in a model (integer)}
#'  \item{\code{depth}: depth of a node or leaf within its tree; the root has depth 0 (integer)}
#'  \item{\code{split_index}: ID of a node in a tree (integer)}
#'  \item{\code{split_feature}: for a node, it's a feature name (character);
#'                              for a leaf, it is missing (\code{NA})}
#'  \item{\code{node_parent}: ID of the parent node for current node (integer)}
#'  \item{\code{leaf_index}: ID of a leaf in a tree (integer)}
#'  \item{\code{leaf_parent}: ID of the parent node for current leaf (integer)}
#'  \item{\code{split_gain}: Split gain of a node}
#'  \item{\code{threshold}: Splitting threshold value of a node}
#'  \item{\code{decision_type}: Decision type of a node}
#'  \item{\code{default_left}: Determine how to handle NA value, TRUE -> Left, FALSE -> Right}
#'  \item{\code{internal_value}: Node value}
#'  \item{\code{internal_count}: The number of observations collected by a node}
#'  \item{\code{leaf_value}: Leaf value}
#'  \item{\code{leaf_count}: The number of observations collected by a leaf}
#' }
#'
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#'
#' params <- list(
#'   objective = "binary"
#'   , learning_rate = 0.01
#'   , num_leaves = 63L
#'   , max_depth = -1L
#'   , min_data_in_leaf = 1L
#'   , min_sum_hessian_in_leaf = 1.0
#'   , num_threads = 2L
#' )
#' model <- lgb.train(params, dtrain, 10L)
#'
#' tree_dt <- lgb.model.dt.tree(model)
#' }
#' @importFrom data.table := rbindlist
#' @importFrom jsonlite fromJSON
#' @export
lgb.model.dt.tree <- function(
    model, num_iteration = NULL, start_iteration = 1L
  ) {

  # Dump the requested range of boosting rounds as JSON text
  json_model <- lgb.dump(
    booster = model
    , num_iteration = num_iteration
    , start_iteration = start_iteration
  )

  # Keep nested JSON objects as plain lists (no data.frame / matrix
  # simplification) so each tree can be walked recursively
  parsed_json_model <- jsonlite::fromJSON(
    txt = json_model
    , simplifyVector = TRUE
    , simplifyDataFrame = FALSE
    , simplifyMatrix = FALSE
    , flatten = FALSE
  )

  # Parse tree model: one data.table per tree
  tree_list <- lapply(
    X = parsed_json_model$tree_info
    , FUN = .single_tree_parse
  )

  # Combine into single data.table, matching columns by name
  tree_dt <- data.table::rbindlist(l = tree_list, use.names = TRUE)

  # Substitute feature index with the actual feature name.
  # Since the index comes from C++ (which is 0-indexed), be sure
  # to add 1 (e.g. index 28 means the 29th feature in feature_names)
  split_feature_indx <- tree_dt[, split_feature] + 1L

  # Get corresponding feature names. Positions in split_feature_indx
  # which are NA (leaf rows) will result in an NA feature name
  feature_names <- parsed_json_model$feature_names[split_feature_indx]
  tree_dt[, split_feature := feature_names]

  return(tree_dt)
}

#' @importFrom data.table := data.table rbindlist
.single_tree_parse <- function(lgb_tree) {

  # Fields copied verbatim from every split (internal) node of the JSON dump
  tree_info_cols <- c(
    "split_index"
    , "split_feature"
    , "split_gain"
    , "threshold"
    , "decision_type"
    , "default_left"
    , "internal_value"
    , "internal_count"
  )

  # Traverse tree function.
  # Walks one tree in pre-order (node, then left subtree, then right subtree)
  # and appends one record per split node or leaf to env$single_tree_dt.
  #
  #   env            - environment holding the list of accumulated records;
  #                    NULL on the first call, which creates it
  #   tree_node_leaf - current node, as a list parsed from the JSON dump
  #   current_depth  - depth of tree_node_leaf; the root is at depth 0
  #   parent_index   - split_index of the parent node; NA for the root
  #
  # Returns the list of records (the first one is an empty, typed data.table).
  pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) {

    if (is.null(env)) {
      # Setup initial default data.table with default types. It carries no
      # rows; it only fixes the column order and base types for rbindlist().
      # NOTE(review): several value columns (e.g. internal_value, leaf_value)
      # are declared integer here and are promoted by rbindlist() when
      # non-integer values arrive.
      env <- new.env(parent = emptyenv())
      env$single_tree_dt <- list()
      env$single_tree_dt[[1L]] <- data.table::data.table(
        tree_index = integer(0L)
        , depth = integer(0L)
        , split_index = integer(0L)
        , split_feature = integer(0L)
        , node_parent = integer(0L)
        , leaf_index = integer(0L)
        , leaf_parent = integer(0L)
        , split_gain = numeric(0L)
        , threshold = numeric(0L)
        , decision_type = character(0L)
        , default_left = character(0L)
        , internal_value = integer(0L)
        , internal_count = integer(0L)
        , leaf_value = integer(0L)
        , leaf_count = integer(0L)
      )
      # start tree traversal
      pre_order_traversal(
        env = env
        , tree_node_leaf = tree_node_leaf
        , current_depth = current_depth
        , parent_index = parent_index
      )
    } else {

      if (!is.null(tree_node_leaf$split_index)) {

        # Split node: record its fields plus its depth and parent
        env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c(
          tree_node_leaf[tree_info_cols]
          , list("depth" = current_depth, "node_parent" = parent_index)
        )

        # Traverse tree again both left and right; children sit one level
        # deeper and have this node as their parent
        pre_order_traversal(
          env = env
          , tree_node_leaf = tree_node_leaf$left_child
          , current_depth = current_depth + 1L
          , parent_index = tree_node_leaf$split_index
        )
        pre_order_traversal(
          env = env
          , tree_node_leaf = tree_node_leaf$right_child
          , current_depth = current_depth + 1L
          , parent_index = tree_node_leaf$split_index
        )
      } else if (!is.null(tree_node_leaf$leaf_index)) {

        # Regular leaf: record its fields plus its depth and parent
        env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c(
          tree_node_leaf[c("leaf_index", "leaf_value", "leaf_count")]
          , list("depth" = current_depth, "leaf_parent" = parent_index)
        )
      } else if (!is.null(tree_node_leaf$leaf_value)) {

        # A tree made of a single leaf is dumped by LightGBM as only
        # {"leaf_value": ...}, without leaf_index or leaf_count (per
        # Tree::ToJSON in LightGBM's C++ source). Record it as leaf 0 so the
        # tree is not silently dropped; rbindlist(fill = TRUE) fills the
        # missing leaf_count with NA.
        env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- list(
          "leaf_index" = 0L
          , "leaf_value" = tree_node_leaf$leaf_value
          , "depth" = current_depth
          , "leaf_parent" = parent_index
        )
      }
    }
    return(env$single_tree_dt)
  }

  # Traverse structure and rowbind everything. fill = TRUE is needed because
  # node records and leaf records carry different subsets of the columns.
  single_tree_dt <- data.table::rbindlist(
    pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
    , use.names = TRUE
    , fill = TRUE
  )

  # Store index
  single_tree_dt[, tree_index := lgb_tree$tree_index]

  return(single_tree_dt)
}