# lgb.interprete.R
#' @name lgb.interprete
#' @title Compute feature contribution of prediction
#' @description Computes feature contribution components of rawscore prediction.
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector with the indices of the rows of \code{data} to interpret.
#' @param num_iteration number of iterations to predict with; NULL or <= 0 means use best iteration.
#'
#' @return For regression, binary classification and lambdarank model, a \code{list} of \code{data.table}
#'         (one element per entry of \code{idxset}) with the following columns:
#'         \itemize{
#'             \item{\code{Feature}: Feature names in the model.}
#'             \item{\code{Contribution}: The total contribution of this feature's splits.}
#'         }
#'         For multiclass classification, a \code{list} of \code{data.table} with the Feature column and
#'         Contribution columns to each class.
#'
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' Logit <- function(x) log(x / (1.0 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' set_field(
#'   dataset = dtrain
#'   , field_name = "init_score"
#'   , data = rep(Logit(mean(train$label)), length(train$label))
#' )
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(
#'     objective = "binary"
#'     , learning_rate = 0.1
#'     , max_depth = -1L
#'     , min_data_in_leaf = 1L
#'     , min_sum_hessian_in_leaf = 1.0
#'     , num_threads = 2L
#' )
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 3L
#' )
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1L:5L)
#' }
#' @importFrom data.table as.data.table
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Tabular description of the model's trees, shared by every interpreted row
  trees <- lgb.model.dt.tree(model = model, num_iteration = num_iteration)

  # Number of classes, read from the booster's private fields
  n_classes <- model$.__enclos_env__$private$num_class

  # Leaf index predictions for the requested rows only. After the transpose
  # each column belongs to one entry of idxset, so converting to a data.table
  # gives one list element (column) per interpreted row.
  requested_rows <- data[idxset, , drop = FALSE]
  leaf_preds <- model$predict(
    data = requested_rows
    , num_iteration = num_iteration
    , predleaf = TRUE
  )
  leaf_cols <- data.table::as.data.table(x = t(leaf_preds))

  # Per interpreted row: leaf indices laid out with one column per class
  leaf_mats <- lapply(
    X = leaf_cols
    , FUN = function(leaves) {
      matrix(leaves, ncol = n_classes, byrow = TRUE)
    }
  )

  # Per interpreted row: the matching 0-based tree indices, in the same layout
  tree_mats <- lapply(
    X = leaf_mats
    , FUN = function(leaf_mat) {
      matrix(seq_along(leaf_mat) - 1L, ncol = n_classes, byrow = TRUE)
    }
  )

  # Interpret every requested row independently
  interpretations <- lapply(
    X = seq_along(idxset)
    , FUN = function(row_pos) {
      .single_row_interprete(
        tree_dt = trees
        , num_class = n_classes
        , tree_index_mat = tree_mats[[row_pos]]
        , leaf_index_mat = leaf_mats[[row_pos]]
      )
    }
  )

  return(interpretations)

}

103
#' @importFrom data.table data.table
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Contribution of each split on the path from the root of one tree down to
  # one of its leaves.
  #
  # tree_dt: table of tree nodes; this function reads the columns tree_index,
  #          leaf_index, leaf_parent, leaf_value, split_index, split_feature,
  #          node_parent and internal_value.
  # tree_id: value of tree_index selecting the tree.
  # leaf_id: value of leaf_index selecting the leaf within that tree.
  #
  # Returns a data.table with one row per split on the path (root first):
  #   Feature      - the split_feature of that split
  #   Contribution - change in node value when stepping from that split to the
  #                  next node on the path

  # Restrict to the requested tree
  single_tree_dt <- tree_dt[tree_index == tree_id, ]

  # The target leaf, and all internal (split) nodes of the tree
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Walk upwards from the leaf. Each step prepends, so both sequences end up
  # ordered root -> leaf. This is a plain loop over local state rather than a
  # recursive closure writing to the enclosing frame with `<<-`.
  feature_seq <- character(0L)
  value_seq <- numeric(0L)
  parent_id <- leaf_dt[["leaf_parent"]]
  current_value <- leaf_dt[["leaf_value"]]

  repeat {

    value_seq <- c(current_value, value_seq)

    # An NA parent means the node just recorded has no parent: the walk is done
    if (is.na(parent_id)) {
      break
    }

    this_node <- node_dt[split_index == parent_id, ]
    feature_seq <- c(this_node[["split_feature"]], feature_seq)

    parent_id <- this_node[["node_parent"]]
    current_value <- this_node[["internal_value"]]

  }

  # value_seq has one more element than feature_seq (it also holds the leaf
  # value), so successive differences line up one-to-one with the splits
  return(
    data.table::data.table(
      Feature = feature_seq
      , Contribution = diff.default(value_seq)
    )
  )

}

155
#' @importFrom data.table := rbindlist setorder
.multiple_tree_interprete <- function(tree_dt,
                                      tree_index,
                                      leaf_index) {

  # Total contribution of every feature over a set of (tree, leaf) pairs.
  #
  # tree_dt:    table of tree nodes, passed through to single.tree.interprete()
  # tree_index: tree identifiers, one per pair
  # leaf_index: leaf identifiers, parallel to tree_index
  #
  # Returns a data.table with columns Feature and Contribution, sorted by
  # decreasing absolute contribution.

  # Path contributions for each (tree, leaf) pair
  per_tree <- Map(
    f = function(tree_id, leaf_id) {
      single.tree.interprete(
        tree_dt = tree_dt
        , tree_id = tree_id
        , leaf_id = leaf_id
      )
    }
    , tree_index
    , leaf_index
  )

  # Stack the per-tree tables and add up the contributions feature by feature
  stacked_dt <- data.table::rbindlist(l = per_tree, use.names = TRUE)
  totals_dt <- stacked_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Sort by magnitude, largest first, using a temporary helper column that is
  # removed again before returning
  totals_dt[, abs_contribution := abs(Contribution)]
  data.table::setorder(x = totals_dt, -abs_contribution)
  totals_dt[, abs_contribution := NULL]

  return(totals_dt)

}

190
#' @importFrom data.table set setnames
.single_row_interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # Feature contributions for a single observation.
  #
  # tree_dt:        table of tree nodes, passed through to
  #                 .multiple_tree_interprete()
  # num_class:      number of classes; column k of the two matrices is used
  #                 for class k
  # tree_index_mat: tree identifiers, one column per class
  # leaf_index_mat: leaf identifiers, same layout as tree_index_mat
  #
  # Returns a data.table. With a single class it has the columns Feature and
  # Contribution; otherwise Feature plus one column "Class <k>" (k counted
  # from 0) per class.

  # One contribution table per class
  per_class <- lapply(
    X = seq_len(num_class)
    , FUN = function(class_pos) {

      class_dt <- .multiple_tree_interprete(
        tree_dt = tree_dt
        , tree_index = tree_index_mat[, class_pos]
        , leaf_index = leaf_index_mat[, class_pos]
      )

      # With several classes, give each contribution column a distinct name
      # so the tables can be merged side by side
      if (num_class > 1L) {
        data.table::setnames(
          x = class_dt
          , old = "Contribution"
          , new = paste("Class", class_pos - 1L)
        )
      }

      return(class_dt)

    }
  )

  # Single class: nothing to combine
  if (num_class == 1L) {
    return(per_class[[1L]])
  }

  # Full outer join of the per-class tables on Feature
  merged_dt <- Reduce(
    f = function(left, right) {
      merge(left, right, by = "Feature", all = TRUE)
    }
    , x = per_class
  )

  # A feature missing from a class comes out of the join as NA; record it as
  # a zero contribution instead. Column 1 is Feature, so start at column 2.
  for (col_pos in seq.int(from = 2L, to = ncol(merged_dt))) {

    missing_rows <- which(is.na(merged_dt[[col_pos]]))
    data.table::set(
      x = merged_dt
      , i = missing_rows
      , j = col_pos
      , value = 0.0
    )

  }

  return(merged_dt)
}