#' Compute feature contribution of prediction
#'
#' Computes feature contribution components of rawscore prediction.
#'
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector of indices of rows needed.
#' @param num_iteration number of iterations to predict with, NULL or <= 0 means use best iteration.
#'
#' @return
#' For regression, binary classification and lambdarank model, a \code{list} of \code{data.table} with the following columns:
#' \itemize{
#'   \item \code{Feature} Feature names in the model.
#'   \item \code{Contribution} The total contribution of this feature's splits.
#' }
#' For multiclass classification, a \code{list} of \code{data.table} with the Feature column and Contribution columns to each class.
#'
#' @examples
#' Sigmoid <- function(x) 1 / (1 + exp(-x))
#' Logit <- function(x) log(x / (1 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(
#'     objective = "binary"
#'     , learning_rate = 0.01
#'     , num_leaves = 63
#'     , max_depth = -1
#'     , min_data_in_leaf = 1
#'     , min_sum_hessian_in_leaf = 1
#' )
#' model <- lgb.train(params, dtrain, 10)
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1:5)
#'
#' @importFrom data.table as.data.table
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Parse the booster into a flat table of tree structures
  tree_dt <- lgb.model.dt.tree(model, num_iteration)

  # Number of classes the booster was trained with
  num_class <- model$.__enclos_env__$private$num_class

  # Preallocate one interpretation result per requested row
  tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset))

  # Predicted leaf indices for the selected rows; transpose so each
  # column of pred_mat corresponds to one row of the input data
  pred_mat <- t(
    model$predict(
      data[idxset, , drop = FALSE]
      , num_iteration = num_iteration
      , predleaf = TRUE
    )
  )
  leaf_index_dt <- data.table::as.data.table(pred_mat)

  # Reshape each row's leaf predictions into a (tree x class) matrix
  leaf_index_mat_list <- lapply(
    X = leaf_index_dt
    , FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE)
  )

  # Matching 0-based tree indices, laid out the same way as the leaves
  tree_index_mat_list <- lapply(
    X = leaf_index_mat_list
    , FUN = function(x) matrix(seq_along(x) - 1, ncol = num_class, byrow = TRUE)
  )

  # Interpret each requested row independently
  for (i in seq_along(idxset)) {
    tree_interpretation_dt_list[[i]] <- single.row.interprete(
      tree_dt
      , num_class
      , tree_index_mat_list[[i]]
      , leaf_index_mat_list[[i]]
    )
  }

  return(tree_interpretation_dt_list)

}

#' @importFrom data.table data.table
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Restrict the model table to the requested tree
  single_tree_dt <- tree_dt[tree_index == tree_id, ]

  # The terminal leaf the observation fell into
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]

  # All internal (split) nodes of this tree
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Walk iteratively from the leaf up to the root, prepending at each
  # step so both sequences end up ordered root -> leaf. This replaces
  # the previous recursive closure that mutated enclosing-scope vectors
  # with `<<-` (and could exhaust the call stack on very deep trees).
  feature_seq <- character(0)
  value_seq <- leaf_dt[["leaf_value"]]
  parent_id <- leaf_dt[["leaf_parent"]]
  while (!is.na(parent_id)) {
    this_node <- node_dt[split_index == parent_id, ]
    feature_seq <- c(this_node[["split_feature"]], feature_seq)
    value_seq <- c(this_node[["internal_value"]], value_seq)
    parent_id <- this_node[["node_parent"]]
  }

  # Each split's contribution is the change in expected value when
  # moving one step down the path, hence the first difference
  data.table::data.table(Feature = feature_seq, Contribution = diff.default(value_seq))

}

#' @importFrom data.table := rbindlist setorder
multiple.tree.interprete <- function(tree_dt,
                                     tree_index,
                                     leaf_index) {

  # Interpret every (tree, leaf) pair for this observation
  per_tree_dts <- Map(
    function(tree_id, leaf_id) {
      single.tree.interprete(tree_dt, tree_id, leaf_id)
    }
    , tree_index
    , leaf_index
  )

  # Stack the per-tree tables and total the contributions per feature
  interp_dt <- data.table::rbindlist(l = per_tree_dts, use.names = TRUE)
  interp_dt <- interp_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Sort features in descending order of absolute contribution; the
  # helper column exists only for the sort and is dropped afterwards
  interp_dt[, abs_contribution := abs(Contribution)]
  data.table::setorder(x = interp_dt, -abs_contribution)
  interp_dt[, abs_contribution := NULL]

  return(interp_dt)

}

#' @importFrom data.table set setnames
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # Preallocate one interpretation table per class
  tree_interpretation <- vector(mode = "list", length = num_class)

  # Interpret the observation's path column-by-column (one column per class)
  for (class_idx in seq_len(num_class)) {

    interp_dt <- multiple.tree.interprete(
      tree_dt = tree_dt
      , tree_index = tree_index_mat[, class_idx]
      , leaf_index = leaf_index_mat[, class_idx]
    )

    # Multiclass models get one contribution column per class
    if (num_class > 1) {
      data.table::setnames(
        interp_dt
        , old = "Contribution"
        , new = paste("Class", class_idx - 1)
      )
    }

    tree_interpretation[[class_idx]] <- interp_dt

  }

  # Check for number of classes larger than 1
  if (num_class == 1) {

    # Single-class model: the result is just the one table
    tree_interpretation_dt <- tree_interpretation[[1L]]

  } else {

    # Merge the per-class tables on Feature, keeping every feature
    # that appears in at least one class
    tree_interpretation_dt <- Reduce(
      f = function(x, y) merge(x, y, by = "Feature", all = TRUE)
      , x = tree_interpretation
    )

    # A feature absent from a class's trees contributes zero there,
    # so replace the NAs introduced by the outer merge
    for (j in 2:ncol(tree_interpretation_dt)) {
      data.table::set(
        tree_interpretation_dt
        , i = which(is.na(tree_interpretation_dt[[j]]))
        , j = j
        , value = 0
      )
    }

  }

  return(tree_interpretation_dt)
}