# callback.R (committed by Guolin Ke)
# State object handed to the training callbacks in this file. R6 objects are
# environments, so callbacks can update its fields in place (cb.early.stop
# sets best_iter and met_early_stop this way).
CB_ENV <- R6Class(
  "lgb.cb_env",
  cloneable = FALSE,
  public = list(
    # booster; callbacks call reset_parameter / update record_evals, best_iter
    model           = NULL,
    # current boosting iteration
    iteration       = NULL,
    # first and last iteration of the current training run
    begin_iteration = NULL,
    end_iteration   = NULL,
    # per-metric results read by the callbacks (data_name, name, value,
    # higher_better)
    eval_list       = list(),
    # optional per-metric error values, parallel to eval_list
    eval_err_list   = list(),
    # written by cb.early.stop when it triggers
    best_iter       = -1,
    met_early_stop  = FALSE
  )
)

cb.reset.parameters <- function(new_params) {
17
18
  if (!is.list(new_params)) { stop(sQuote("new_params"), " must be a list") }
  pnames  <- gsub("\\.", "_", names(new_params))
Guolin Ke's avatar
Guolin Ke committed
19
  nrounds <- NULL
20

Guolin Ke's avatar
Guolin Ke committed
21
22
23
  # run some checks in the begining
  init <- function(env) {
    nrounds <<- env$end_iteration - env$begin_iteration + 1
24
25
26

    if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) }

Guolin Ke's avatar
Guolin Ke committed
27
28
    # Some parameters are not allowed to be changed,
    # since changing them would simply wreck some chaos
29
30
31
32
33
    not_allowed <- c("num_class", "metric", "boosting_type")
    if (any(pnames %in% not_allowed)) {
      stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting")
    }

Guolin Ke's avatar
Guolin Ke committed
34
35
36
37
    for (n in pnames) {
      p <- new_params[[n]]
      if (is.function(p)) {
        if (length(formals(p)) != 2)
38
          stop("Parameter ", sQuote(n), " is a function but not of two arguments")
Guolin Ke's avatar
Guolin Ke committed
39
40
      } else if (is.numeric(p) || is.character(p)) {
        if (length(p) != nrounds)
41
          stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
Guolin Ke's avatar
Guolin Ke committed
42
      } else {
43
        stop("Parameter ", sQuote(n), " is not a function or a vector")
Guolin Ke's avatar
Guolin Ke committed
44
45
46
      }
    }
  }
47

Guolin Ke's avatar
Guolin Ke committed
48
  callback <- function(env) {
49
    if (is.null(nrounds)) { init(env) }
Guolin Ke's avatar
Guolin Ke committed
50
51
    i <- env$iteration - env$begin_iteration
    pars <- lapply(new_params, function(p) {
52
      if (is.function(p)) { return(p(i, nrounds)) }
Guolin Ke's avatar
Guolin Ke committed
53
54
55
      p[i]
    })
    # to-do check pars
56
    if (!is.null(env$model)) { env$model$reset_parameter(pars) }
Guolin Ke's avatar
Guolin Ke committed
57
  }
58
59

  attr(callback, 'call')             <- match.call()
Guolin Ke's avatar
Guolin Ke committed
60
  attr(callback, 'is_pre_iteration') <- TRUE
61
62
  attr(callback, 'name')             <- 'cb.reset.parameters'
  callback
Guolin Ke's avatar
Guolin Ke committed
63
64
65
}

# Format the evaluation metric string
66
67
format.eval.string <- function(eval_res, eval_err = NULL) {
  if (is.null(eval_res) || length(eval_res) == 0) { stop('no evaluation results') }
Guolin Ke's avatar
Guolin Ke committed
68
  if (!is.null(eval_err)) {
69
    sprintf('%s\'s %s:%g+%g', eval_res$data_name, eval_res$name, eval_res$value, eval_err)
Guolin Ke's avatar
Guolin Ke committed
70
  } else {
71
    sprintf('%s\'s %s:%g', eval_res$data_name, eval_res$name, eval_res$value)
Guolin Ke's avatar
Guolin Ke committed
72
73
74
  }
}

# Build the log line for the current iteration: "[<iteration>]:" followed by
# one formatted result per entry of env$eval_list, tab-separated. Returns ""
# when there are no evaluation results.
merge.eval.string <- function(env) {
  if (length(env$eval_list) <= 0) { return("") }
  has_err <- length(env$eval_err_list) > 0
  # eval_err_list, when present, is indexed in parallel with eval_list
  pieces <- lapply(seq_along(env$eval_list), function(j) {
    eval_err <- if (has_err) env$eval_err_list[[j]] else NULL
    format.eval.string(env$eval_list[[j]], eval_err)
  })
  header <- sprintf('[%d]:', env$iteration)
  paste0(c(header, unlist(pieces, use.names = FALSE)), collapse = '\t')
}

# Callback that prints the merged evaluation line (see merge.eval.string)
# when (iteration - 1) is a multiple of `period`, and always on the first and
# last iteration. A period <= 0 disables printing.
cb.print.evaluation <- function(period = 1) {
  callback <- function(env) {
    if (period > 0) {
      i <- env$iteration
      # Scalar conditions: short-circuit `||` rather than vectorised `|`
      if ((i - 1) %% period == 0 ||
          i == env$begin_iteration ||
          i == env$end_iteration) {
        cat(merge.eval.string(env), "\n")
      }
    }
  }
  attr(callback, 'call') <- match.call()
  attr(callback, 'name') <- 'cb.print.evaluation'
  callback
}

# Callback that appends each iteration's evaluation results to
# env$model$record_evals, laid out as
#   record_evals$start_iter                       (begin_iteration)
#   record_evals[[data_name]][[name]]$eval        (list of values)
#   record_evals[[data_name]][[name]]$eval_err    (list of error values)
cb.record.evaluation <- function() {
  callback <- function(env) {
    if (length(env$eval_list) <= 0) { return() }
    has_err <- length(env$eval_err_list) > 0

    # First call: create the empty record structure for every metric
    if (length(env$model$record_evals) == 0) {
      env$model$record_evals$start_iter <- env$begin_iteration
      for (res in env$eval_list) {
        if (is.null(env$model$record_evals[[res$data_name]])) {
          env$model$record_evals[[res$data_name]] <- list()
        }
        env$model$record_evals[[res$data_name]][[res$name]] <- list(eval = list(), eval_err = list())
      }
    }

    for (j in seq_along(env$eval_list)) {
      res <- env$eval_list[[j]]
      # eval_err_list, when present, is indexed in parallel with eval_list
      eval_err <- if (has_err) env$eval_err_list[[j]] else NULL
      entry <- env$model$record_evals[[res$data_name]][[res$name]]
      entry$eval     <- c(entry$eval, res$value)
      entry$eval_err <- c(entry$eval_err, eval_err)
      env$model$record_evals[[res$data_name]][[res$name]] <- entry
    }
  }
  attr(callback, 'call') <- match.call()
  attr(callback, 'name') <- 'cb.record.evaluation'
  callback
}

# Early-stopping callback. Tracks, for each metric in env$eval_list, the best
# score seen so far and the iteration it occurred at. When some metric has not
# improved for `stopping_rounds` iterations, it sets env$best_iter,
# env$met_early_stop <- TRUE and (if present) env$model$best_iter.
#
# stopping_rounds: number of non-improving iterations tolerated.
# verbose: if TRUE, print progress/early-stopping messages.
cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
  # State kept across calls; initialised on the first call, once the number
  # and direction of the evaluation metrics are known.
  factor_to_bigger_better <- NULL
  best_iter               <- NULL
  best_score              <- NULL
  best_msg                <- NULL
  eval_len                <- NULL

  init <- function(env) {
    eval_len <<- length(env$eval_list)
    if (eval_len == 0) {
      stop("For early stopping, valids must have at least one element")
    }

    if (isTRUE(verbose)) {
      cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = '')
    }

    # Negate lower-is-better metrics so "improvement" always means "bigger"
    factor_to_bigger_better <<- rep(1.0, eval_len)
    for (i in seq_len(eval_len)) {
      if (!env$eval_list[[i]]$higher_better) {
        factor_to_bigger_better[i] <<- -1.0
      }
    }
    best_iter  <<- rep(-1, eval_len)
    best_score <<- rep(-Inf, eval_len)
    best_msg   <<- as.list(rep("", eval_len))
  }

  # `finalize` is accepted for interface compatibility; it is not used here.
  callback <- function(env, finalize = FALSE) {
    if (is.null(eval_len)) { init(env) }
    cur_iter <- env$iteration
    for (i in seq_len(eval_len)) {
      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]
      if (score > best_score[i]) {
        best_score[i] <<- score
        best_iter[i]  <<- cur_iter
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(merge.eval.string(env))
        }
      } else if (cur_iter - best_iter[i] >= stopping_rounds) {
        if (!is.null(env$model)) { env$model$best_iter <- best_iter[i] }
        if (isTRUE(verbose)) {
          cat("Early stopping, best iteration is:", "\n")
          cat(best_msg[[i]], "\n")
        }
        env$best_iter      <- best_iter[i]
        env$met_early_stop <- TRUE
        # One triggering metric is enough: do not print the message again or
        # overwrite env$best_iter from a later metric.
        break
      }
    }
  }
  attr(callback, 'call') <- match.call()
  attr(callback, 'name') <- 'cb.early.stop'
  callback
}

# Extract callback names from the list of callbacks: the "name" attribute of
# each element (elements without one contribute nothing to the result).
# exact = TRUE prevents "name" from partially matching a "names" attribute.
callback.names <- function(cb_list) {
  unlist(lapply(cb_list, attr, which = "name", exact = TRUE))
}

# Append callback `cb` to `cb_list`, name every entry after its "name"
# attribute, and move a "cb.early.stop" entry (if any) to the end of the list.
add.cb <- function(cb_list, cb) {
  cb_list <- c(cb_list, cb)
  names(cb_list) <- callback.names(cb_list)
  if ('cb.early.stop' %in% names(cb_list)) {
    # Append a copy at the end, then drop the original: assigning NULL by
    # name removes only the first matching entry.
    cb_list <- c(cb_list, cb_list['cb.early.stop'])
    cb_list['cb.early.stop'] <- NULL
  }
  cb_list
}

# Split callbacks into those run before each iteration (attribute
# is_pre_iteration is TRUE) and all others. Every callback lands in exactly
# one group; a missing, NA or non-logical attribute counts as post-iteration.
categorize.callbacks <- function(cb_list) {
  is_pre <- vapply(
    cb_list,
    function(x) isTRUE(attr(x, 'is_pre_iteration', exact = TRUE)),
    logical(1)
  )
  list(
    pre_iter  = cb_list[is_pre],
    post_iter = cb_list[!is_pre]
  )
}