# lgb.interprete.R
#' @name lgb.interprete
#' @title Compute feature contribution of prediction
#' @description Computes feature contribution components of rawscore prediction.
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector of indices of rows needed.
#' @param num_iteration number of iterations to predict with, NULL or <= 0 means use best iteration.
#'
#' @return For regression, binary classification and lambdarank model, a \code{list} of \code{data.table}
#'         with the following columns:
#'         \itemize{
#'             \item{\code{Feature}: Feature names in the model.}
#'             \item{\code{Contribution}: The total contribution of this feature's splits.}
#'         }
#'         For multiclass classification, a \code{list} of \code{data.table} with the Feature column and
#'         Contribution columns to each class.
#'
#' @examples
#' \dontrun{
#' Logit <- function(x) log(x / (1.0 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(
#'     objective = "binary"
#'     , learning_rate = 0.1
#'     , max_depth = -1L
#'     , min_data_in_leaf = 1L
#'     , min_sum_hessian_in_leaf = 1.0
#' )
#' model <- lgb.train(
#'     params = params
#'     , data = dtrain
#'     , nrounds = 3L
#' )
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1L:5L)
#' }
#' @importFrom data.table as.data.table
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Flat table describing every node and leaf of every tree in the model
  tree_dt <- lgb.model.dt.tree(model, num_iteration)

  # Number of classes stored on the booster (1 unless multiclass)
  num_class <- model$.__enclos_env__$private$num_class

  # Pre-allocate one result slot per requested row
  tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset))

  # Leaf-index predictions for the requested rows. The result is transposed
  # so that, once converted to a data.table, each column holds the leaf
  # indices of a single requested row and can be iterated with lapply().
  pred_mat <- t(
    model$predict(
      data[idxset, , drop = FALSE]
      , num_iteration = num_iteration
      , predleaf = TRUE
    )
  )
  leaf_index_dt <- data.table::as.data.table(pred_mat)

  # Reshape each row's leaf indices so that consecutive runs of `num_class`
  # entries form one matrix row; column k is later consumed as class k.
  leaf_index_mat_list <- lapply(
    X = leaf_index_dt
    , FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE)
  )

  # Matching zero-based tree ids, laid out exactly like the leaf indices
  tree_index_mat_list <- lapply(
    X = leaf_index_mat_list
    , FUN = function(x) {
      matrix(seq_along(x) - 1L, ncol = num_class, byrow = TRUE)
    }
  )

  # Interpret each requested row independently
  for (i in seq_along(idxset)) {
    tree_interpretation_dt_list[[i]] <- single.row.interprete(
      tree_dt
      , num_class
      , tree_index_mat_list[[i]]
      , leaf_index_mat_list[[i]]
    )
  }

  # Return interpretation list
  return(tree_interpretation_dt_list)

}

96
# Feature contributions along one root-to-leaf path of a single tree.
#
# tree_dt : data.table of tree structure; the columns read here are
#           tree_index, leaf_index, leaf_parent, leaf_value, split_index,
#           split_feature, node_parent and internal_value.
# tree_id : value matched against the tree_index column.
# leaf_id : value matched against the leaf_index column within that tree.
#
# Returns a data.table with one row per split on the path, ordered from the
# root down: `Feature` is the split feature of the node and `Contribution`
# is the change in value when moving from that node to the next one on the
# path (the last step ends at the leaf value).
#' @importFrom data.table data.table
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Restrict to the requested tree
  single_tree_dt <- tree_dt[tree_index == tree_id, ]

  # The leaf this observation fell into
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]

  # All internal (split) nodes of the tree
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Walk from the leaf up to the root, prepending as we go so both
  # sequences end up ordered root -> leaf. A missing parent id marks the root.
  feature_seq <- character(0L)
  value_seq <- leaf_dt[["leaf_value"]]
  parent_id <- leaf_dt[["leaf_parent"]]

  while (!is.na(parent_id)) {
    this_node <- node_dt[split_index == parent_id, ]
    feature_seq <- c(this_node[["split_feature"]], feature_seq)
    value_seq <- c(this_node[["internal_value"]], value_seq)
    parent_id <- this_node[["node_parent"]]
  }

  # value_seq has one more element than feature_seq; successive differences
  # attribute each change in value to the split that caused it.
  data.table::data.table(
    Feature = feature_seq
    , Contribution = diff.default(value_seq)
  )

}

143
# Total contribution of each feature over a set of trees.
#
# tree_dt    : data.table of tree structure, forwarded to
#              single.tree.interprete().
# tree_index : vector of tree ids, one per tree to evaluate.
# leaf_index : vector of leaf ids, parallel to `tree_index`.
#
# Returns a data.table with columns `Feature` and `Contribution` (summed over
# all supplied trees), sorted by decreasing absolute contribution.
#' @importFrom data.table := rbindlist setorder
multiple.tree.interprete <- function(tree_dt,
                                     tree_index,
                                     leaf_index) {

  # Interpret every (tree, leaf) pair on its own
  per_tree_list <- Map(
    function(tree_id, leaf_id) {
      single.tree.interprete(
        tree_dt = tree_dt
        , tree_id = tree_id
        , leaf_id = leaf_id
      )
    }
    , tree_index
    , leaf_index
  )

  # Stack the per-tree tables and add up contributions feature by feature
  stacked_dt <- data.table::rbindlist(l = per_tree_list, use.names = TRUE)
  contrib_dt <- stacked_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Order by magnitude using a temporary helper column, then discard it
  contrib_dt[, abs_contribution := abs(Contribution)]
  data.table::setorder(x = contrib_dt, -abs_contribution)
  contrib_dt[, abs_contribution := NULL]

  return(contrib_dt)

}

179
# Feature contributions for a single observation across all classes.
#
# tree_dt        : data.table of tree structure, forwarded to
#                  multiple.tree.interprete().
# num_class      : number of classes; column i of the two matrices below is
#                  used for class i.
# tree_index_mat : matrix of tree ids, one column per class.
# leaf_index_mat : matrix of leaf ids, same layout as `tree_index_mat`.
#
# Returns a data.table. With one class it has columns `Feature` and
# `Contribution`. With several classes it has `Feature` plus one column per
# class named "Class 0", "Class 1", ...; features missing for a class get 0.
#' @importFrom data.table set setnames
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # One contribution table per class
  tree_interpretation <- vector(mode = "list", length = num_class)

  # Loop over classes
  for (i in seq_len(num_class)) {

    next_interp_dt <- multiple.tree.interprete(
      tree_dt = tree_dt
      , tree_index = tree_index_mat[, i]
      , leaf_index = leaf_index_mat[, i]
    )

    # For multiclass models, label the column with the zero-based class id
    # so the per-class tables can be merged side by side.
    if (num_class > 1L) {
      data.table::setnames(
        next_interp_dt
        , old = "Contribution"
        , new = paste("Class", i - 1L)
      )
    }

    tree_interpretation[[i]] <- next_interp_dt

  }

  # Single class: nothing to merge
  if (num_class == 1L) {
    return(tree_interpretation[[1L]])
  }

  # Several classes: full outer join of all per-class tables on Feature
  tree_interpretation_dt <- Reduce(
    f = function(x, y) {
      merge(x, y, by = "Feature", all = TRUE)
    }
    , x = tree_interpretation
  )

  # The outer join leaves NA where a feature was not used for a class;
  # replace those with a zero contribution. Column 1 is Feature, so skip it.
  for (j in seq_len(ncol(tree_interpretation_dt))[-1L]) {
    data.table::set(
      tree_interpretation_dt
      , i = which(is.na(tree_interpretation_dt[[j]]))
      , j = j
      , value = 0.0
    )
  }

  # Return interpretation tree
  return(tree_interpretation_dt)
}