#' @name lgb.interprete
#' @title Compute feature contribution of prediction
#' @description Computes feature contribution components of rawscore prediction.
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector of indices of rows needed.
#' @param num_iteration number of iterations to predict with, NULL or <= 0 means use best iteration.
#'
#' @return For regression, binary classification and lambdarank model, a \code{list} of \code{data.table}
#'         with the following columns:
#'         \itemize{
#'             \item{\code{Feature}: Feature names in the model.}
#'             \item{\code{Contribution}: The total contribution of this feature's splits.}
#'         }
#'         For multiclass classification, a \code{list} of \code{data.table} with the Feature column and
#'         Contribution columns to each class.
#'
#' @examples
#' \dontrun{
#' Logit <- function(x) log(x / (1.0 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(
#'     objective = "binary"
#'     , learning_rate = 0.1
#'     , max_depth = -1L
#'     , min_data_in_leaf = 1L
#'     , min_sum_hessian_in_leaf = 1.0
#' )
#'
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 3L
#' )
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1L:5L)
#' }
#' @importFrom data.table as.data.table
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Parsed tree structure of the fitted model
  model_tree_dt <- lgb.model.dt.tree(model, num_iteration)

  # Number of classes (1 for regression, binary classification and lambdarank)
  n_class <- model$.__enclos_env__$private$num_class

  # Predicted leaf indices for the requested rows; after transposing,
  # each column of the matrix corresponds to one requested row
  leaf_pred_mat <- t(
    model$predict(
      data[idxset, , drop = FALSE]
      , num_iteration = num_iteration
      , predleaf = TRUE
    )
  )

  # One (num_tree x num_class) leaf-index matrix per requested row
  per_row_leaf_mats <- lapply(
    X = data.table::as.data.table(leaf_pred_mat)
    , FUN = function(col) matrix(col, ncol = n_class, byrow = TRUE)
  )

  # Matching tree-id matrices (0-based ids, laid out the same way)
  per_row_tree_mats <- lapply(
    X = per_row_leaf_mats
    , FUN = function(m) {
      matrix(seq_len(length(m)) - 1L, ncol = n_class, byrow = TRUE)
    }
  )

  # Interpret each requested row independently and return the list of tables
  lapply(
    X = seq_along(idxset)
    , FUN = function(i) {
      single.row.interprete(
        model_tree_dt
        , n_class
        , per_row_tree_mats[[i]]
        , per_row_leaf_mats[[i]]
      )
    }
  )

}

#' @importFrom data.table data.table
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Restrict the model table to the requested tree
  this_tree <- tree_dt[tree_index == tree_id, ]

  # The terminal leaf the prediction landed in
  leaf_row <- this_tree[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]

  # All internal (split) nodes of this tree
  split_rows <- this_tree[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Walk from the leaf up to the root, prepending at each step so the
  # sequences end up ordered root-first, leaf-last
  features <- character(0L)
  values <- c(leaf_row[["leaf_value"]])
  parent <- leaf_row[["leaf_parent"]]
  while (!is.na(parent)) {
    node_row <- split_rows[split_index == parent, ]
    features <- c(node_row[["split_feature"]], features)
    values <- c(node_row[["internal_value"]], values)
    parent <- node_row[["node_parent"]]
  }

  # Contribution of each split = change in value along the root-to-leaf path
  data.table::data.table(
    Feature = features
    , Contribution = diff.default(values)
  )

}

#' @importFrom data.table := rbindlist setorder
multiple.tree.interprete <- function(tree_dt,
                                     tree_index,
                                     leaf_index) {

  # Interpret every (tree, leaf) pair and stack the per-tree tables
  per_tree_list <- Map(
    f = function(tid, lid) {
      single.tree.interprete(tree_dt, tid, lid)
    }
    , tree_index
    , leaf_index
  )
  result_dt <- data.table::rbindlist(l = per_tree_list, use.names = TRUE)

  # Aggregate to one total contribution per feature
  result_dt <- result_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Order features by magnitude of contribution, largest first
  result_dt[, abs_contribution := abs(Contribution)]
  data.table::setorder(
    x = result_dt
    , -abs_contribution
  )

  # The helper column was only needed for ordering
  result_dt[, abs_contribution := NULL]

  return(result_dt)

}

#' @importFrom data.table set setnames
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # Build one interpretation table per class
  per_class_dts <- lapply(
    X = seq_len(num_class)
    , FUN = function(class_idx) {

      class_dt <- multiple.tree.interprete(
        tree_dt = tree_dt
        , tree_index = tree_index_mat[, class_idx]
        , leaf_index = leaf_index_mat[, class_idx]
      )

      # For multiclass models, label each contribution column with its class
      if (num_class > 1L) {
        data.table::setnames(
          class_dt
          , old = "Contribution"
          , new = paste("Class", class_idx - 1L)
        )
      }

      class_dt

    }
  )

  if (num_class == 1L) {

    # Single-class model: the lone table is already the answer
    result_dt <- per_class_dts[[1L]]

  } else {

    # Multiclass: outer-merge the per-class tables on Feature ...
    result_dt <- Reduce(
      f = function(left, right) {
        merge(left, right, by = "Feature", all = TRUE)
      }
      , x = per_class_dts
    )

    # ... then treat a feature absent from a class as zero contribution
    for (col in 2L:ncol(result_dt)) {
      data.table::set(
        result_dt
        , i = which(is.na(result_dt[[col]]))
        , j = col
        , value = 0.0
      )
    }

  }

  return(result_dt)

}