#' Compute feature contribution of prediction
#'
#' Computes feature contribution components of rawscore prediction.
#'
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector of indices of rows needed.
#' @param num_iteration number of iterations you want to predict with, NULL or <= 0 means use best iteration.
#'
#' @return
#' For regression, binary classification and lambdarank models, a \code{list} of \code{data.table}s with the following columns:
#' \itemize{
#'   \item \code{Feature} Feature names in the model.
#'   \item \code{Contribution} The total contribution of this feature's splits.
#' }
#' For multiclass classification, a \code{list} of \code{data.table}s with the Feature column and Contribution columns for each class.
#'
#' @examples
#' Sigmoid <- function(x) 1 / (1 + exp(-x))
#' Logit <- function(x) log(x / (1 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(
#'     objective = "binary"
#'     , learning_rate = 0.01
#'     , num_leaves = 63
#'     , max_depth = -1
#'     , min_data_in_leaf = 1
#'     , min_sum_hessian_in_leaf = 1
#' )
#' model <- lgb.train(params, dtrain, 10)
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1:5)
#'
#' @importFrom data.table as.data.table
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Parsed tree structure of the fitted booster
  tree_dt <- lgb.model.dt.tree(model, num_iteration)

  # Number of classes (1 except for multiclass objectives).
  # NOTE(review): reaches into the R6 private environment — assumes
  # lgb.Booster keeps `num_class` there; confirm against the Booster class.
  num_class <- model$.__enclos_env__$private$num_class

  # Predict leaf indices for the requested rows, then transpose so that
  # each column of the data.table corresponds to one row of `idxset`.
  leaf_preds <- model$predict(
    data[idxset, , drop = FALSE]
    , num_iteration = num_iteration
    , predleaf = TRUE
  )
  leaf_index_dt <- data.table::as.data.table(t(leaf_preds))

  # One (num_tree x num_class) matrix of leaf indices per selected row
  leaf_index_mat_list <- lapply(
    X = leaf_index_dt
    , FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE)
  )

  # Matching matrices of 0-based tree indices, same shape as the leaf matrices
  tree_index_mat_list <- lapply(
    X = leaf_index_mat_list
    , FUN = function(x) {
      matrix(seq_len(length(x)) - 1, ncol = num_class, byrow = TRUE)
    }
  )

  # Interpret each requested row independently
  tree_interpretation_dt_list <- lapply(
    X = seq_along(idxset)
    , FUN = function(i) {
      single.row.interprete(
        tree_dt
        , num_class
        , tree_index_mat_list[[i]]
        , leaf_index_mat_list[[i]]
      )
    }
  )

  # One interpretation data.table per requested row
  return(tree_interpretation_dt_list)
}

#' @importFrom data.table data.table
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Restrict the model table to the requested tree
  single_tree_dt <- tree_dt[tree_index == tree_id, ]

  # The leaf this observation fell into
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]

  # All internal (split) nodes of the tree
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Walk from the leaf up to the root, prepending at each step so both
  # sequences end up ordered root -> leaf.
  feature_seq <- character(0)
  value_seq <- numeric(0)

  parent_id <- leaf_dt[["leaf_parent"]]
  value_seq <- c(leaf_dt[["leaf_value"]], value_seq)

  # An NA parent means the root has been reached
  while (!is.na(parent_id)) {
    this_node <- node_dt[split_index == parent_id, ]
    feature_seq <- c(this_node[["split_feature"]], feature_seq)
    value_seq <- c(this_node[["internal_value"]], value_seq)
    parent_id <- this_node[["node_parent"]]
  }

  # Contribution of each split = change in value along the root -> leaf path,
  # so there is one contribution per feature in `feature_seq`.
  data.table::data.table(
    Feature = feature_seq
    , Contribution = diff.default(value_seq)
  )
}

#' @importFrom data.table := rbindlist setorder
multiple.tree.interprete <- function(tree_dt,
                                     tree_index,
                                     leaf_index) {

  # Interpret every (tree, leaf) pair, then stack the per-tree results.
  # Map() is mapply(..., SIMPLIFY = FALSE, USE.NAMES = TRUE).
  per_tree_dts <- Map(
    f = function(tree_id, leaf_id) {
      single.tree.interprete(tree_dt, tree_id, leaf_id)
    }
    , tree_index
    , leaf_index
  )
  interp_dt <- data.table::rbindlist(l = per_tree_dts, use.names = TRUE)

  # Collapse to one total contribution per feature
  interp_dt <- interp_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Order features by decreasing absolute contribution
  interp_dt[, abs_contribution := abs(Contribution)]
  data.table::setorder(
    x = interp_dt
    , -abs_contribution
  )

  # The helper column was only needed for sorting
  interp_dt[, abs_contribution := NULL]

  return(interp_dt)
}

#' @importFrom data.table set setnames
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # Compute, for a single observation, the per-feature contributions of every
  # class. Returns a data.table with a Feature column and either a single
  # Contribution column (num_class == 1) or one "Class <k>" column per class.

  # One interpretation data.table per class
  tree_interpretation <- vector(mode = "list", length = num_class)

  # Loop over each class
  for (i in seq_len(num_class)) {

    # Column i holds this class's tree/leaf indices for the observation
    next_interp_dt <- multiple.tree.interprete(
      tree_dt = tree_dt
      , tree_index = tree_index_mat[, i]
      , leaf_index = leaf_index_mat[, i]
    )

    # For multiclass models, disambiguate the contribution column per class
    # (classes are 0-based, matching LightGBM's convention)
    if (num_class > 1) {
      data.table::setnames(
        next_interp_dt
        , old = "Contribution"
        , new = paste("Class", i - 1)
      )
    }

    tree_interpretation[[i]] <- next_interp_dt

  }

  # Check for number of classes larger than 1
  if (num_class == 1) {

    # Single class: the one interpretation element is the result
    tree_interpretation_dt <- tree_interpretation[[1]]

  } else {

    # Multiclass: outer-join all per-class tables on Feature
    tree_interpretation_dt <- Reduce(
      f = function(x, y){
        merge(x, y, by = "Feature", all = TRUE)
      }
      , x = tree_interpretation
    )

    # A feature absent from one class's trees shows up as NA after the outer
    # join; a missing feature contributes nothing, so replace NA with 0.
    # Column 1 is Feature, so contribution columns start at 2; seq_len()[-1L]
    # is used instead of the unsafe 2:ncol(...) sequence idiom.
    for (j in seq_len(ncol(tree_interpretation_dt))[-1L]) {
      data.table::set(
        tree_interpretation_dt
        , i = which(is.na(tree_interpretation_dt[[j]]))
        , j = j
        , value = 0
      )
    }

  }

  # Return interpretation tree
  return(tree_interpretation_dt)
}