# lgb.interprete.R — feature-contribution (interpretation) utilities
#' @name lgb.interprete
#' @title Compute feature contribution of prediction
#' @description Computes feature contribution components of rawscore prediction.
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector of indices of rows needed.
#' @param num_iteration number of iterations to predict with, NULL or <= 0 means use best iteration.
#'
#' @return For regression, binary classification and lambdarank model, a \code{list} of \code{data.table}
#'         with the following columns:
#'         \itemize{
#'             \item{\code{Feature}: Feature names in the model.}
#'             \item{\code{Contribution}: The total contribution of this feature's splits.}
#'         }
#'         For multiclass classification, a \code{list} of \code{data.table} with the Feature column and
#'         Contribution columns to each class.
#'
#' @examples
#' \donttest{
#' Logit <- function(x) log(x / (1.0 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(
#'     objective = "binary"
#'     , learning_rate = 0.1
#'     , max_depth = -1L
#'     , min_data_in_leaf = 1L
#'     , min_sum_hessian_in_leaf = 1.0
#' )
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 3L
#' )
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1L:5L)
#' }
#' @importFrom data.table as.data.table
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Parse the fitted trees into one data.table (one row per split node / leaf)
  tree_dt <- lgb.model.dt.tree(model = model, num_iteration = num_iteration)

  # Number of classes (1 for regression / binary / lambdarank)
  # NOTE(review): reaches into the R6 private environment — depends on
  # lgb.Booster internals; confirm no public accessor exists
  num_class <- model$.__enclos_env__$private$num_class

  # Pre-allocate one result slot per requested row
  tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset))

  # Predicted leaf indices; after t(), one column per requested row
  pred_mat <- t(
    model$predict(
      data = data[idxset, , drop = FALSE]
      , num_iteration = num_iteration
      , predleaf = TRUE
    )
  )
  leaf_index_dt <- data.table::as.data.table(x = pred_mat)

  # For each requested row: an (iterations x num_class) matrix of leaf indices
  leaf_index_mat_list <- lapply(
    X = leaf_index_dt
    , FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE)
  )

  # Matching matrices of 0-based tree indices, same shape as the leaf matrices
  tree_index_mat_list <- lapply(
    X = leaf_index_mat_list
    , FUN = function(x) {
      matrix(seq_along(x) - 1L, ncol = num_class, byrow = TRUE)
    }
  )

  # Interpret each requested row independently
  for (i in seq_along(idxset)) {
    tree_interpretation_dt_list[[i]] <- single.row.interprete(
      tree_dt = tree_dt
      , num_class = num_class
      , tree_index_mat = tree_index_mat_list[[i]]
      , leaf_index_mat = leaf_index_mat_list[[i]]
    )
  }

  return(tree_interpretation_dt_list)
}
#' @importFrom data.table data.table
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Restrict the model table to the requested tree
  single_tree_dt <- tree_dt[tree_index == tree_id, ]

  # The target leaf and the tree's internal split nodes
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Accumulators: split features and node values ordered root -> leaf
  feature_seq <- character(0L)
  value_seq <- numeric(0L)

  # Walk upward from the leaf to the root, prepending at each step so the
  # sequences end up in root-to-leaf order
  parent_id <- leaf_dt[["leaf_parent"]]
  current_value <- leaf_dt[["leaf_value"]]
  repeat {

    # Record the value of the node we are standing on
    value_seq <- c(current_value, value_seq)

    # An NA parent means we have recorded the root; stop climbing
    if (is.na(parent_id)) {
      break
    }

    # Move to the parent split node and remember its feature
    this_node <- node_dt[split_index == parent_id, ]
    feature_seq <- c(this_node[["split_feature"]], feature_seq)
    parent_id <- this_node[["node_parent"]]
    current_value <- this_node[["internal_value"]]

  }

  # Each split's contribution is the change in value across that split
  return(
    data.table::data.table(
      Feature = feature_seq
      , Contribution = diff.default(value_seq)
    )
  )

}
#' @importFrom data.table := rbindlist setorder
multiple.tree.interprete <- function(tree_dt,
                                     tree_index,
                                     leaf_index) {

  # Interpret each (tree, leaf) pair and stack the per-tree tables
  # (Map is mapply with SIMPLIFY = FALSE, USE.NAMES = TRUE)
  per_tree_list <- Map(
    function(tree_id, leaf_id) {
      single.tree.interprete(
        tree_dt = tree_dt
        , tree_id = tree_id
        , leaf_id = leaf_id
      )
    }
    , tree_index
    , leaf_index
  )
  interp_dt <- data.table::rbindlist(l = per_tree_list, use.names = TRUE)

  # Total contribution per feature across all trees
  interp_dt <- interp_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Sort features by decreasing absolute contribution, via a temporary column
  interp_dt[, abs_contribution := abs(Contribution)]
  data.table::setorder(
    x = interp_dt
    , -abs_contribution
  )

  # The absolute value was only needed for sorting
  interp_dt[, abs_contribution := NULL]

  return(interp_dt)

}
#' @importFrom data.table set setnames
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # One interpretation table per class
  per_class_list <- vector(mode = "list", length = num_class)

  # Interpret each class from its own column of tree / leaf indices
  for (class_idx in seq_len(num_class)) {

    class_dt <- multiple.tree.interprete(
      tree_dt = tree_dt
      , tree_index = tree_index_mat[, class_idx]
      , leaf_index = leaf_index_mat[, class_idx]
    )

    # In the multiclass case, label the contribution column with its class
    if (num_class > 1L) {
      data.table::setnames(
        x = class_dt
        , old = "Contribution"
        , new = paste("Class", class_idx - 1L)
      )
    }

    per_class_list[[class_idx]] <- class_dt

  }

  # Single class: the one table is the result
  if (num_class == 1L) {

    tree_interpretation_dt <- per_class_list[[1L]]

  } else {

    # Multiclass: full outer join of all class tables on Feature
    tree_interpretation_dt <- Reduce(
      f = function(x, y) {
        merge(x, y, by = "Feature", all = TRUE)
      }
      , x = per_class_list
    )

    # A feature absent from a class's trees contributes zero to that class
    for (col in 2L:ncol(tree_interpretation_dt)) {

      data.table::set(
        x = tree_interpretation_dt
        , i = which(is.na(tree_interpretation_dt[[col]]))
        , j = col
        , value = 0.0
      )

    }

  }

  return(tree_interpretation_dt)
}