# lgb.interprete.R
#' @name lgb.interprete
#' @title Compute feature contribution of prediction
#' @description Computes feature contribution components of rawscore prediction.
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector of indices of rows needed.
#' @param num_iteration number of iterations to predict with, NULL or <= 0 means use best iteration.
#'
#' @return For regression, binary classification and lambdarank model, a \code{list} of \code{data.table}
#'         with the following columns:
#'         \itemize{
#'             \item{\code{Feature}: Feature names in the model.}
#'             \item{\code{Contribution}: The total contribution of this feature's splits.}
#'         }
#'         For multiclass classification, a \code{list} of \code{data.table} with the Feature column and
#'         Contribution columns to each class.
#'         The list has one element per entry of \code{idxset}.
#'
#' @examples
#' \donttest{
#' Logit <- function(x) log(x / (1.0 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(
#'     objective = "binary"
#'     , learning_rate = 0.1
#'     , max_depth = -1L
#'     , min_data_in_leaf = 1L
#'     , min_sum_hessian_in_leaf = 1.0
#' )
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 3L
#' )
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1L:5L)
#' }
#' @importFrom data.table as.data.table
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Tabular description of the model's trees; it is handed unchanged to
  # single.row.interprete() for every requested row
  tree_dt <- lgb.model.dt.tree(model = model, num_iteration = num_iteration)

  # Number of classes, read from the Booster's private environment
  num_class <- model$.__enclos_env__$private$num_class

  # Pre-allocate the result: one interpretation per entry of idxset
  tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset))

  # Leaf predictions for the requested rows. The result is transposed so that
  # converting it to a data.table yields one column per predicted row.
  pred_mat <- t(
    model$predict(
      data = data[idxset, , drop = FALSE]
      , num_iteration = num_iteration
      , predleaf = TRUE
    )
  )
  leaf_index_dt <- data.table::as.data.table(x = pred_mat)

  # One leaf-index matrix per predicted row, with one column per class
  leaf_index_mat_list <- lapply(
    X = leaf_index_dt
    , FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE)
  )

  # Zero-based tree ids, laid out exactly like the leaf-index matrices. The
  # layout depends only on the number of leaf predictions per row, so it is
  # built once here instead of once per row.
  tree_index_mat <- matrix(
    seq_len(nrow(pred_mat)) - 1L
    , ncol = num_class
    , byrow = TRUE
  )

  # Interpret each requested row
  for (i in seq_along(idxset)) {
    tree_interpretation_dt_list[[i]] <- single.row.interprete(
      tree_dt = tree_dt
      , num_class = num_class
      , tree_index_mat = tree_index_mat
      , leaf_index_mat = leaf_index_mat_list[[i]]
    )
  }

  return(tree_interpretation_dt_list)

}

# Internal helper: contributions of the splits on one leaf's path in one tree.
#
# Arguments:
#   tree_dt  data.table describing the trees; this function reads the columns
#            tree_index, leaf_index, leaf_parent, leaf_value, split_index,
#            split_feature, node_parent and internal_value.
#   tree_id  value of tree_index identifying the tree.
#   leaf_id  value of leaf_index identifying the leaf within that tree.
#
# Returns a data.table with one row per split between the root and the leaf
# (root first). Feature is the split's feature; Contribution is the change in
# node value when stepping from that split to the next node on the path.
#' @importFrom data.table data.table
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Restrict to the requested tree
  single_tree_dt <- tree_dt[tree_index == tree_id, ]

  # The leaf the row ended up in
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]
  if (nrow(leaf_dt) == 0L) {
    stop("single.tree.interprete: leaf ", leaf_id, " not found in tree ", tree_id)
  }

  # Internal (split) nodes of the tree
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Walk from the leaf up to the root, prepending as we go so that both
  # sequences end up ordered root -> leaf. An NA parent marks the end of the
  # walk (a tree without any split stops immediately).
  feature_seq <- character(0L)
  value_seq <- leaf_dt[["leaf_value"]]
  parent_id <- leaf_dt[["leaf_parent"]]

  while (!is.na(parent_id)) {
    this_node <- node_dt[split_index == parent_id, ]
    feature_seq <- c(this_node[["split_feature"]], feature_seq)
    value_seq <- c(this_node[["internal_value"]], value_seq)
    parent_id <- this_node[["node_parent"]]
  }

  # value_seq has one more element than feature_seq; consecutive differences
  # attribute each step's change in value to the feature that was split on
  data.table::data.table(
    Feature = feature_seq
    , Contribution = diff(value_seq)
  )

}

# Internal helper: total contribution per feature over several trees.
#
# tree_index and leaf_index are walked in parallel: element k of each names a
# tree and the leaf reached in that tree. The per-tree tables produced by
# single.tree.interprete() are stacked, summed per Feature, and returned as a
# data.table (columns Feature, Contribution) ordered by decreasing absolute
# contribution.
#' @importFrom data.table := rbindlist setorder
multiple.tree.interprete <- function(tree_dt,
                                     tree_index,
                                     leaf_index) {

  # One contribution table per (tree, leaf) pair
  per_tree_dts <- Map(
    f = function(tree_id, leaf_id) {
      single.tree.interprete(
        tree_dt = tree_dt
        , tree_id = tree_id
        , leaf_id = leaf_id
      )
    }
    , tree_index
    , leaf_index
  )

  # Stack all trees and add up the contributions of each feature
  stacked_dt <- data.table::rbindlist(l = per_tree_dts, use.names = TRUE)
  totals_dt <- stacked_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Order by decreasing magnitude using a temporary helper column, which is
  # removed again before returning
  totals_dt[, abs_contribution := abs(Contribution)]
  data.table::setorderv(
    x = totals_dt
    , cols = "abs_contribution"
    , order = -1L
  )
  totals_dt[, abs_contribution := NULL]

  return(totals_dt)

}

# Internal helper: feature contributions for a single predicted row.
#
# Arguments:
#   tree_dt         data.table describing the trees, passed through to
#                   multiple.tree.interprete().
#   num_class       number of classes; column i of both matrices belongs to
#                   class i.
#   tree_index_mat  matrix of tree ids, one column per class.
#   leaf_index_mat  matrix of leaf ids with the same layout as tree_index_mat.
#
# Returns, when num_class is 1, the data.table from multiple.tree.interprete()
# (columns Feature, Contribution). Otherwise returns one data.table with a
# Feature column plus one column per class named "Class 0", "Class 1", ...,
# where a feature that does not appear for a class gets 0.
#' @importFrom data.table set setnames
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # Interpret the trees of each class separately
  tree_interpretation <- lapply(
    X = seq_len(num_class)
    , FUN = function(i) {

      next_interp_dt <- multiple.tree.interprete(
        tree_dt = tree_dt
        , tree_index = tree_index_mat[, i]
        , leaf_index = leaf_index_mat[, i]
      )

      # With several classes, label each contribution column by its
      # zero-based class number so the tables can be merged side by side
      if (num_class > 1L) {
        data.table::setnames(
          x = next_interp_dt
          , old = "Contribution"
          , new = paste("Class", i - 1L)
        )
      }

      next_interp_dt

    }
  )

  # Single class: nothing to merge
  if (num_class == 1L) {
    return(tree_interpretation[[1L]])
  }

  # Several classes: full outer join on Feature so that every feature seen for
  # any class is kept
  tree_interpretation_dt <- Reduce(
    f = function(x, y) {
      merge(x, y, by = "Feature", all = TRUE)
    }
    , x = tree_interpretation
  )

  # A feature missing for a class comes out of the join as NA; it contributed
  # nothing to that class, so record 0 instead
  for (class_col in setdiff(names(tree_interpretation_dt), "Feature")) {
    data.table::set(
      x = tree_interpretation_dt
      , i = which(is.na(tree_interpretation_dt[[class_col]]))
      , j = class_col
      , value = 0.0
    )
  }

  return(tree_interpretation_dt)
}