# lgb.interprete.R
#' @name lgb.interprete
#' @title Compute feature contribution of prediction
#' @description Computes feature contribution components of rawscore prediction.
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector of indices of rows needed.
#' @param num_iteration number of iterations to predict with; NULL or <= 0 means use the best iteration.
#'
#' @return For regression, binary classification and lambdarank model, a \code{list} of \code{data.table}
#'         with the following columns:
#'         \itemize{
#'             \item{\code{Feature}: Feature names in the model.}
#'             \item{\code{Contribution}: The total contribution of this feature's splits.}
#'         }
#'         For multiclass classification, a \code{list} of \code{data.table} with the Feature column and
#'         Contribution columns to each class.
#'
#' @examples
#' Logit <- function(x) log(x / (1.0 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(
#'     objective = "binary"
#'     , learning_rate = 0.1
#'     , max_depth = -1L
#'     , min_data_in_leaf = 1L
#'     , min_sum_hessian_in_leaf = 1.0
#' )
#'
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 3L
#' )
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1L:5L)
#'
#' @importFrom data.table as.data.table
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Parse the fitted booster into a flat data.table of trees
  tree_dt <- lgb.model.dt.tree(model, num_iteration)

  # Number of classes (1 for regression / binary / lambdarank models)
  num_class <- model$.__enclos_env__$private$num_class

  # Leaf-index predictions for the requested rows; transposed so that
  # each column of the matrix corresponds to one row of `data`
  pred_mat <- t(
    model$predict(
      data[idxset, , drop = FALSE]
      , num_iteration = num_iteration
      , predleaf = TRUE
    )
  )

  # For every observation, reshape its leaf indices into a
  # (num_trees x num_class) matrix, plus a parallel matrix of tree ids
  per_row_leaf_mats <- lapply(
    X = data.table::as.data.table(pred_mat)
    , FUN = function(leaf_ids) {
      matrix(leaf_ids, ncol = num_class, byrow = TRUE)
    }
  )
  per_row_tree_mats <- lapply(
    X = per_row_leaf_mats
    , FUN = function(leaf_mat) {
      matrix(seq_along(leaf_mat) - 1L, ncol = num_class, byrow = TRUE)
    }
  )

  # Interpret each requested row independently
  tree_interpretation_dt_list <- lapply(
    X = seq_along(idxset)
    , FUN = function(row_pos) {
      single.row.interprete(
        tree_dt
        , num_class
        , per_row_tree_mats[[row_pos]]
        , per_row_leaf_mats[[row_pos]]
      )
    }
  )

  # Return one interpretation data.table per requested row
  return(tree_interpretation_dt_list)
}

#' @importFrom data.table data.table
# Compute the per-split contributions along the root-to-leaf path of a
# single tree for one observation.
#
# @param tree_dt data.table produced by \code{lgb.model.dt.tree}.
# @param tree_id index of the tree to interpret.
# @param leaf_id index of the leaf the observation landed in.
# @return A \code{data.table} with columns \code{Feature} (split feature on
#   the path, root first) and \code{Contribution} (the change in predicted
#   value introduced by each split on the path).
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Restrict to the requested tree
  single_tree_dt <- tree_dt[tree_index == tree_id, ]

  # The terminal leaf the observation landed in
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]

  # All internal (split) nodes of the tree
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Walk iteratively from the leaf up to the root, prepending the split
  # feature and node output value at every step so the root ends up first.
  # (Iterative rather than a recursive closure with `<<-`: same sequences,
  # no C-stack growth on deep trees.)
  feature_seq <- character(0L)
  value_seq <- leaf_dt[["leaf_value"]]
  parent_id <- leaf_dt[["leaf_parent"]]
  while (!is.na(parent_id)) {
    this_node <- node_dt[split_index == parent_id, ]
    feature_seq <- c(this_node[["split_feature"]], feature_seq)
    value_seq <- c(this_node[["internal_value"]], value_seq)
    parent_id <- this_node[["node_parent"]]
  }

  # Each split's contribution is the difference between consecutive node
  # values on the root-to-leaf path.
  data.table::data.table(
    Feature = feature_seq
    , Contribution = diff.default(value_seq)
  )
}

#' @importFrom data.table := rbindlist setorder
# Aggregate per-tree feature contributions for one observation (one class)
# across a set of trees, returning one row per feature sorted by the
# absolute value of its total contribution (largest first).
multiple.tree.interprete <- function(tree_dt,
                                     tree_index,
                                     leaf_index) {

  # Interpret every (tree, leaf) pair, then stack the per-tree tables
  per_tree_dts <- Map(
    f = function(tree_id, leaf_id) {
      single.tree.interprete(tree_dt = tree_dt, tree_id = tree_id, leaf_id = leaf_id)
    }
    , tree_index
    , leaf_index
  )
  interp_dt <- data.table::rbindlist(l = per_tree_dts, use.names = TRUE)

  # Sum contributions per feature across all trees
  interp_dt <- interp_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Sort features by contribution magnitude, descending
  interp_dt[, abs_contribution := abs(Contribution)]
  data.table::setorder(
    x = interp_dt
    , -abs_contribution
  )

  # The absolute value was only needed for ordering
  interp_dt[, abs_contribution := NULL]

  return(interp_dt)
}

#' @importFrom data.table set setnames
# Build the interpretation table for a single observation.
#
# @param tree_dt data.table produced by \code{lgb.model.dt.tree}.
# @param num_class number of model classes.
# @param tree_index_mat (num_trees x num_class) matrix of tree ids.
# @param leaf_index_mat (num_trees x num_class) matrix of leaf ids.
# @return For \code{num_class == 1L}, a data.table with Feature and
#   Contribution columns; otherwise one "Class <k>" column per class.
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # One interpretation table per class
  tree_interpretation <- lapply(
    X = seq_len(num_class)
    , FUN = function(class_pos) {
      class_dt <- multiple.tree.interprete(
        tree_dt = tree_dt
        , tree_index = tree_index_mat[, class_pos]
        , leaf_index = leaf_index_mat[, class_pos]
      )
      # For multiclass models, label the contribution column by class
      if (num_class > 1L) {
        data.table::setnames(
          class_dt
          , old = "Contribution"
          , new = paste("Class", class_pos - 1L)
        )
      }
      class_dt
    }
  )

  # Single-class models need no merging
  if (num_class == 1L) {
    return(tree_interpretation[[1L]])
  }

  # Outer-join all per-class tables on the Feature column
  tree_interpretation_dt <- Reduce(
    f = function(x, y) {
      merge(x, y, by = "Feature", all = TRUE)
    }
    , x = tree_interpretation
  )

  # A feature absent from a class contributed nothing there: replace the
  # NAs introduced by the outer join with zero, column by column
  for (j in 2L:ncol(tree_interpretation_dt)) {
    data.table::set(
      tree_interpretation_dt
      , i = which(is.na(tree_interpretation_dt[[j]]))
      , j = j
      , value = 0.0
    )
  }

  return(tree_interpretation_dt)
}