lgb.interprete.R 6.3 KB
Newer Older
1
2
3
#' @name lgb.interprete
#' @title Compute feature contribution of prediction
#' @description Computes feature contribution components of the raw score prediction.
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector of indices of rows needed.
#' @param num_iteration number of iterations to predict with; NULL or <= 0 means use the best iteration.
#'
#' @return For regression, binary classification and lambdarank model, a \code{list} of \code{data.table}
#'         with the following columns:
#'         \itemize{
#'             \item{\code{Feature}: Feature names in the model.}
#'             \item{\code{Contribution}: The total contribution of this feature's splits.}
#'         }
#'         For multiclass classification, a \code{list} of \code{data.table} with the Feature column and
#'         one contribution column for each class.
#'
#' @examples
#' Sigmoid <- function(x) 1.0 / (1.0 + exp(-x))
#' Logit <- function(x) log(x / (1.0 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(
#'     objective = "binary"
#'     , learning_rate = 0.01
#'     , num_leaves = 63L
#'     , max_depth = -1L
#'     , min_data_in_leaf = 1L
#'     , min_sum_hessian_in_leaf = 1.0
#' )
#' model <- lgb.train(params, dtrain, 10L)
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1L:5L)
#'
#' @importFrom data.table as.data.table
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Tabular description of the model's trees, as produced by lgb.model.dt.tree
  tree_dt <- lgb.model.dt.tree(model, num_iteration)

  # Number of classes stored on the booster
  num_class <- model$.__enclos_env__$private$num_class

  # One result slot per requested row
  tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset))

  # Predict with predleaf = TRUE for the requested rows only, then transpose
  # so that each column of the resulting table belongs to one entry of idxset
  leaf_preds <- model$predict(
    data[idxset, , drop = FALSE]
    , num_iteration = num_iteration
    , predleaf = TRUE
  )
  leaf_index_dt <- data.table::as.data.table(t(leaf_preds))

  for (row_pos in seq_along(idxset)) {

    # Leaf indices of this row, reshaped row-wise into one column per class
    row_leaf_mat <- matrix(
      leaf_index_dt[[row_pos]]
      , ncol = num_class
      , byrow = TRUE
    )

    # Zero-based tree ids laid out in the same shape as the leaf matrix
    row_tree_mat <- matrix(
      seq_along(row_leaf_mat) - 1L
      , ncol = num_class
      , byrow = TRUE
    )

    tree_interpretation_dt_list[[row_pos]] <- single.row.interprete(
      tree_dt = tree_dt
      , num_class = num_class
      , tree_index_mat = row_tree_mat
      , leaf_index_mat = row_leaf_mat
    )

  }

  # List of interpretation tables, in the order of idxset
  tree_interpretation_dt_list

}

93
#' @importFrom data.table data.table
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Internal helper: for one tree and one reached leaf, list the split
  # features on the root-to-leaf path together with the change in node value
  # at each step. Returns a data.table with columns Feature and Contribution.

  # Rows of the requested tree only
  single_tree_dt <- tree_dt[tree_index == tree_id, ]

  # The leaf that was reached
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]

  # All split (non-leaf) nodes of the tree
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Both sequences are built by prepending, so they end up ordered root-first
  feature_seq <- character(0L)
  value_seq <- numeric(0L)

  # Start at the leaf and climb until a node without a parent is reached
  up_id <- leaf_dt[["leaf_parent"]]
  up_value <- leaf_dt[["leaf_value"]]

  repeat {

    # Record the value of the node currently visited
    value_seq <- c(up_value, value_seq)

    # An NA parent id means there is nothing above this node
    if (is.na(up_id)) {
      break
    }

    # Move to the parent split node and record the feature it splits on
    this_node <- node_dt[split_index == up_id, ]
    feature_seq <- c(this_node[["split_feature"]], feature_seq)
    up_id <- this_node[["node_parent"]]
    up_value <- this_node[["internal_value"]]

  }

  # Successive differences of the node values give one contribution per split
  data.table::data.table(
    Feature = feature_seq
    , Contribution = diff.default(value_seq)
  )

}

140
#' @importFrom data.table := rbindlist setorder
multiple.tree.interprete <- function(tree_dt,
                                     tree_index,
                                     leaf_index) {

  # Internal helper: interpret several (tree, leaf) pairs, add up the
  # contributions per feature and return them sorted by decreasing
  # absolute contribution.

  # Interpret each tree together with the leaf reached in it
  per_tree <- Map(
    function(one_tree, one_leaf) {
      single.tree.interprete(
        tree_dt = tree_dt
        , tree_id = one_tree
        , leaf_id = one_leaf
      )
    }
    , tree_index
    , leaf_index
  )

  # Stack the per-tree tables and total the contributions of each feature
  stacked_dt <- data.table::rbindlist(l = per_tree, use.names = TRUE)
  interp_dt <- stacked_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Temporary helper column: magnitude of each contribution
  data.table::set(
    interp_dt
    , j = "abs_contribution"
    , value = abs(interp_dt[["Contribution"]])
  )

  # Reorder in place, largest magnitude first
  data.table::setorderv(
    x = interp_dt
    , cols = "abs_contribution"
    , order = -1L
  )

  # The helper column was only needed for sorting
  data.table::set(
    interp_dt
    , j = "abs_contribution"
    , value = NULL
  )

  return(interp_dt)

}

176
#' @importFrom data.table set setnames
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # Internal helper: build the interpretation table of one data row.
  # Column i of tree_index_mat / leaf_index_mat holds the tree ids and the
  # reached leaf ids that are passed on for class i.

  # One interpretation table per class
  tree_interpretation <- lapply(
    X = seq_len(num_class)
    , FUN = function(class_pos) {

      class_dt <- multiple.tree.interprete(
        tree_dt = tree_dt
        , tree_index = tree_index_mat[, class_pos]
        , leaf_index = leaf_index_mat[, class_pos]
      )

      # With several classes, label the value column by its zero-based class
      if (num_class > 1L) {
        data.table::setnames(
          class_dt
          , old = "Contribution"
          , new = paste("Class", class_pos - 1L)
        )
      }

      class_dt

    }
  )

  # Single class: the only table is already the result
  if (num_class == 1L) {
    return(tree_interpretation[[1L]])
  }

  # Several classes: full outer join of all class tables on Feature
  tree_interpretation_dt <- Reduce(
    f = function(x, y) {
      merge(x, y, by = "Feature", all = TRUE)
    }
    , x = tree_interpretation
  )

  # A feature missing for some class shows up as NA after the join;
  # replace those by 0 in every column after Feature
  for (col_pos in 2L:ncol(tree_interpretation_dt)) {

    missing_rows <- which(is.na(tree_interpretation_dt[[col_pos]]))
    data.table::set(
      tree_interpretation_dt
      , i = missing_rows
      , j = col_pos
      , value = 0.0
    )

  }

  return(tree_interpretation_dt)
}