# lgb.interprete.R 5.87 KB
# Newer Older
1
#' Compute feature contribution of prediction
#'
#' Computes feature contribution components of rawscore prediction.
#'
#' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object.
#' @param idxset an integer vector of indices of the rows to interpret.
#' @param num_iteration number of iterations to predict with; NULL or <= 0 means use the best iteration.
#'
#' @return
#' A \code{list} with one \code{data.table} per element of \code{idxset}, in the same order.
#'
#' For regression, binary classification and lambdarank models, each \code{data.table} has the following columns:
#' \itemize{
#'   \item \code{Feature} Feature names in the model.
#'   \item \code{Contribution} The total contribution of this feature's splits.
#' }
#' For multiclass classification, each \code{data.table} has the \code{Feature} column and one contribution
#' column per class, named \code{"Class 0"}, \code{"Class 1"}, and so on.
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' Sigmoid <- function(x) 1 / (1 + exp(-x))
#' Logit <- function(x) log(x / (1 - x))
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#'
#' params <- list(objective = "binary",
#'                learning_rate = 0.01, num_leaves = 63, max_depth = -1,
#'                min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1)
#' model <- lgb.train(params, dtrain, 20)
#'
#' tree_interpretation <- lgb.interprete(model, test$data, 1:5)
#' }
#'
#' @importFrom magrittr %>% %T>%
#' @export
lgb.interprete <- function(model,
                           data,
                           idxset,
                           num_iteration = NULL) {

  # Nothing to interpret: the per-row loop below would produce an empty list
  # anyway, so return it without dumping the model or calling predict.
  if (length(idxset) == 0L) {
    return(list())
  }

  # Flat table of every tree's split nodes and leaves
  tree_dt <- lgb.model.dt.tree(model, num_iteration)

  # Number of classes, read from the Booster's private fields
  num_class <- model$.__enclos_env__$private$num_class

  # One result slot per requested row
  tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset))

  # Leaf reached in every tree, for each requested row. After t() each column
  # belongs to one row; it is then reshaped so that column i holds the trees of
  # class i. NOTE(review): assumes predleaf returns one row per data row and one
  # column per tree, with classes interleaved within each iteration.
  leaf_index_mat_list <- model$predict(data[idxset, , drop = FALSE],
                                       num_iteration = num_iteration,
                                       predleaf = TRUE) %>%
    t(.) %>%
    data.table::as.data.table(.) %>%
    lapply(., FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE))

  # Matching 0-based tree indices, laid out exactly like the leaf indices
  tree_index_mat_list <- lapply(leaf_index_mat_list,
                                FUN = function(x) matrix(seq_len(length(x)) - 1, ncol = num_class, byrow = TRUE))

  # Interpret each requested row independently
  for (i in seq_along(idxset)) {
    tree_interpretation_dt_list[[i]] <- single.row.interprete(tree_dt, num_class, tree_index_mat_list[[i]], leaf_index_mat_list[[i]])
  }

  tree_interpretation_dt_list
}

78
79
80
81
82
# Contribution of each split on the path from the root of one tree down to one leaf.
#
# tree_dt : table of tree nodes/leaves (as built by lgb.model.dt.tree), with
#           columns tree_index, split_index, split_feature, node_parent,
#           internal_value, leaf_index, leaf_parent, leaf_value.
# tree_id : index of the tree to inspect.
# leaf_id : index of the leaf the row landed in within that tree.
#
# Returns a data.table with one row per split on the path (root first):
# Feature is the split feature and Contribution is the change in value from
# that node to the next node on the path (the last step ends at the leaf value).
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {

  # Restrict to the requested tree
  single_tree_dt <- tree_dt[tree_index == tree_id, ]

  # The leaf the row ended in
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]

  # The internal (split) nodes of this tree
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]

  # Walk from the leaf up to the root, prepending so both sequences end up
  # ordered root -> leaf. value_seq holds one more entry than feature_seq
  # (the leaf value), so diff() yields exactly one contribution per split.
  feature_seq <- character(0)
  value_seq <- c(leaf_dt[["leaf_value"]], numeric(0))
  parent_id <- leaf_dt[["leaf_parent"]]

  # A missing parent marks the root
  while (!is.na(parent_id)) {
    this_node <- node_dt[split_index == parent_id, ]
    value_seq <- c(this_node[["internal_value"]], value_seq)
    feature_seq <- c(this_node[["split_feature"]], feature_seq)
    parent_id <- this_node[["node_parent"]]
  }

  data.table::data.table(Feature = feature_seq, Contribution = diff.default(value_seq))
}

121
122
123
124
125
# Aggregate per-feature contributions over several trees for one row.
#
# tree_dt    : table of tree nodes/leaves passed through to single.tree.interprete.
# tree_index : vector of tree indices to inspect.
# leaf_index : vector of leaf indices, parallel to tree_index.
#
# Returns a data.table with columns Feature and Contribution (summed over all
# given trees), sorted by absolute Contribution, largest first.
multiple.tree.interprete <- function(tree_dt,
                                     tree_index,
                                     leaf_index) {

  # Interpret every (tree, leaf) pair independently, then stack the results
  per_tree_list <- mapply(single.tree.interprete,
                          tree_id = tree_index, leaf_id = leaf_index,
                          MoreArgs = list(tree_dt = tree_dt),
                          SIMPLIFY = FALSE, USE.NAMES = TRUE)
  interp_dt <- data.table::rbindlist(per_tree_list, use.names = TRUE)

  # No split on any path (no trees given, or every path is a bare leaf):
  # return a well-formed empty table instead of relying on how rbindlist
  # represents empty input when aggregating by Feature.
  if (nrow(interp_dt) == 0L) {
    return(data.table::data.table(Feature = character(0), Contribution = numeric(0)))
  }

  # Total contribution per feature
  interp_dt <- interp_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

  # Largest absolute contribution first
  interp_dt[order(abs(Contribution), decreasing = TRUE)]
}

# Interpret one row of data across all classes.
#
# tree_dt        : table of tree nodes/leaves passed through to multiple.tree.interprete.
# num_class      : number of classes; column i of the matrices below belongs to class i.
# tree_index_mat : matrix of tree indices, one column per class.
# leaf_index_mat : matrix of leaf indices reached, laid out like tree_index_mat.
#
# Returns the Feature/Contribution table when num_class == 1. Otherwise returns
# one table with a Feature column and a "Class <k>" column per class (k is
# 0-based); features absent from a class's paths get 0 for that class.
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {

  # One interpretation table per class
  tree_interpretation <- vector(mode = "list", length = num_class)

  for (i in seq_len(num_class)) {

    class_dt <- multiple.tree.interprete(tree_dt, tree_index_mat[, i], leaf_index_mat[, i])

    # With several classes, name each contribution column after its 0-based
    # class so the per-class tables can be merged side by side.
    # setnames() renames in place (by reference).
    if (num_class > 1) {
      data.table::setnames(class_dt, old = "Contribution", new = paste("Class", i - 1))
    }

    tree_interpretation[[i]] <- class_dt
  }

  # Single class: the one table is the answer
  if (num_class == 1) {
    return(tree_interpretation[[1]])
  }

  # Outer-join all per-class tables on Feature
  tree_interpretation_dt <- Reduce(f = function(x, y) merge(x, y, by = "Feature", all = TRUE),
                                   x = tree_interpretation)

  # A feature missing from a class's table contributes nothing to that class;
  # skip column 1 (Feature) and zero-fill the NA cells of every class column
  for (j in seq_len(ncol(tree_interpretation_dt))[-1L]) {
    data.table::set(tree_interpretation_dt,
                    i = which(is.na(tree_interpretation_dt[[j]])),
                    j = j,
                    value = 0)
  }

  tree_interpretation_dt
}