callback.R 10.2 KB
Newer Older
1
2
3
4
5
6
7
8
# Names of the elements used inside recorded evaluation lists (see
# cb.record.evaluation). Both are exposed as zero-argument functions.

# Element holding the recorded metric values.
.EVAL_KEY <- function() "eval"

# Element holding the recorded metric error values.
.EVAL_ERR_KEY <- function() "eval_err"

James Lamb's avatar
James Lamb committed
9
10
# Mutable state container for training callbacks. Its fields mirror what the
# callbacks in this file read from / write to their `env` argument
# (presumably an instance of this class -- confirm in the training loop).
#' @importFrom R6 R6Class
CB_ENV <- R6::R6Class(
  classname = "lgb.cb_env",
  cloneable = FALSE,
  public = list(
    # Booster being trained
    model = NULL,
    # Iteration bookkeeping: current, first and last iteration of the run
    iteration = NULL,
    begin_iteration = NULL,
    end_iteration = NULL,
    # Evaluation results (and their errors) for the current iteration
    eval_list = list(),
    eval_err_list = list(),
    # Early-stopping outcome
    best_iter = -1L,
    best_score = NA,
    met_early_stop = FALSE
  )
)

# Build a pre-iteration callback that changes booster parameters during
# training.
#
# new_params: a named list. Each element is either
#   * a function of exactly two arguments, called as f(i, nrounds), where `i`
#     is the 0-based offset of the current iteration from env$begin_iteration
#     and `nrounds` is env$end_iteration - env$begin_iteration + 1; or
#   * a numeric or character vector with exactly `nrounds` values; the value
#     at position i + 1 is used for offset i.
# Dots in parameter names are replaced by underscores (e.g. "learning.rate"
# becomes "learning_rate").
#
# Validation runs on the first call of the callback and stops with an error
# for a missing env$model, a function without exactly two arguments, a vector
# of the wrong length, or any other kind of value. When env$model is not NULL,
# the values selected for the current round are passed to
# env$model$reset_parameter().
#
# Returns the callback, tagged with the attributes "call",
# "is_pre_iteration" (TRUE) and "name".
cb.reset.parameters <- function(new_params) {

  if (!identical(class(new_params), "list")) {
    stop(sQuote("new_params"), " must be a list")
  }

  # Normalise parameter names once, so that validation (which looks values up
  # by the normalised name) and the list handed to reset_parameter() use the
  # same keys; otherwise any name containing a dot would be reported as "not
  # a function or a vector".
  pnames <- gsub("\\.", "_", names(new_params))
  if (!is.null(names(new_params))) {
    names(new_params) <- pnames
  }
  nrounds <- NULL

  # One-time checks, run on the first invocation of the callback
  init <- function(env) {

    if (is.null(env$model)) {
      stop("Env should have a ", sQuote("model"))
    }

    # Number of boosting rounds in this run
    nrounds <<- env$end_iteration - env$begin_iteration + 1L

    for (n in pnames) {

      p <- new_params[[n]]

      if (is.function(p)) {

        # Functions are called as p(i, nrounds): exactly two formals required
        if (length(formals(p)) != 2L) {
          stop("Parameter ", sQuote(n), " is a function but not of two arguments")
        }

      } else if (is.numeric(p) || is.character(p)) {

        # Vectors must provide one value per boosting round
        if (length(p) != nrounds) {
          stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
        }

      } else {

        stop("Parameter ", sQuote(n), " is not a function or a vector")

      }

    }

  }

  callback <- function(env) {

    if (is.null(nrounds)) {
      init(env = env)
    }

    # 0-based offset of the current iteration from the first one
    i <- env$iteration - env$begin_iteration

    pars <- lapply(new_params, function(p) {
      if (is.function(p)) {
        return(p(i, nrounds))
      }
      # Vectors hold one value per round and R indexing is 1-based, so the
      # value for offset i sits at position i + 1 (p[i] would be empty on the
      # first round and the last value would never be used).
      p[i + 1L]
    })

    if (!is.null(env$model)) {
      env$model$reset_parameter(params = pars)
    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "is_pre_iteration") <- TRUE
  attr(callback, "name") <- "cb.reset.parameters"
  callback
}

# Render one evaluation result as "<data_name>'s <metric>:<value>", appending
# "+<err>" when `eval_err` is supplied. Numbers are formatted with "%g".
# Stops with an error when `eval_res` is NULL or empty.
format.eval.string <- function(eval_res, eval_err = NULL) {

  # length(NULL) is 0, so this also covers a NULL result
  if (length(eval_res) == 0L) {
    stop("no evaluation results")
  }

  if (is.null(eval_err)) {
    return(sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value))
  }

  sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)
}

126
# Build the one-line evaluation summary for the current iteration: the
# "[<iteration>]:" header followed by one format.eval.string() piece per entry
# of env$eval_list, all joined by two spaces. Error values are attached when
# env$eval_err_list is non-empty. Returns "" when there are no results.
merge.eval.string <- function(env) {

  n_evals <- length(env$eval_list)
  if (n_evals <= 0L) {
    return("")
  }

  has_err <- length(env$eval_err_list) > 0L

  pieces <- lapply(seq_len(n_evals), function(k) {
    err_k <- if (has_err) env$eval_err_list[[k]] else NULL
    format.eval.string(eval_res = env$eval_list[[k]], eval_err = err_k)
  })

  header <- sprintf("[%d]:", env$iteration)
  paste0(c(list(header), pieces), collapse = "  ")
}

157
# Build a post-iteration callback that prints the evaluation summary produced
# by merge.eval.string().
#
# A message is printed when (iteration - 1) %% period == 0, and always on
# env$begin_iteration and env$end_iteration. Nothing is printed when `period`
# is not positive, or when merge.eval.string() returns "" (no evaluation
# results).
#
# Returns the callback, tagged with the attributes "call" and "name".
cb.print.evaluation <- function(period = 1L) {

  callback <- function(env) {

    # A non-positive period disables printing entirely
    if (period <= 0L) {
      return(invisible(NULL))
    }

    i <- env$iteration

    # Every `period`-th iteration counting from iteration 1, plus the first
    # and last iterations of this run
    on_schedule <- (i - 1L) %% period == 0L
    if (on_schedule || is.element(i, c(env$begin_iteration, env$end_iteration))) {

      msg <- merge.eval.string(env = env)

      # Print the string already built instead of rebuilding it
      if (nchar(msg) > 0L) {
        print(msg)
      }

    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.print.evaluation"
  callback
}

# Build a post-iteration callback that appends each iteration's evaluation
# results to env$model$record_evals.
#
# Layout built up in record_evals:
#   start_iter                                  <- env$begin_iteration
#   [[data_name]][[metric]][[.EVAL_KEY()]]      list of recorded values
#   [[data_name]][[metric]][[.EVAL_ERR_KEY()]]  list of recorded errors
#                                               (only grows when
#                                               env$eval_err_list is non-empty)
# The nested lists are created on the first call that finds record_evals
# empty. Nothing is recorded when env$eval_list is empty.
#
# Returns the callback, tagged with the attributes "call" and "name".
cb.record.evaluation <- function() {

  callback <- function(env) {

    evals <- env$eval_list
    if (length(evals) <= 0L) {
      return()
    }
    has_err <- length(env$eval_err_list) > 0L

    # First call: remember the starting iteration and create one empty record
    # per data set / metric pair
    if (length(env$model$record_evals) == 0L) {
      env$model$record_evals$start_iter <- env$begin_iteration
      for (res in evals) {
        dn <- res$data_name
        nm <- res$name
        if (is.null(env$model$record_evals[[dn]])) {
          env$model$record_evals[[dn]] <- list()
        }
        env$model$record_evals[[dn]][[nm]] <- list()
        env$model$record_evals[[dn]][[nm]][[.EVAL_KEY()]] <- list()
        env$model$record_evals[[dn]][[nm]][[.EVAL_ERR_KEY()]] <- list()
      }
    }

    # Append this iteration's values (and errors, when present)
    for (j in seq_along(evals)) {
      res <- evals[[j]]
      err <- if (has_err) env$eval_err_list[[j]] else NULL
      dn <- res$data_name
      nm <- res$name
      env$model$record_evals[[dn]][[nm]][[.EVAL_KEY()]] <- c(
        env$model$record_evals[[dn]][[nm]][[.EVAL_KEY()]]
        , res$value
      )
      env$model$record_evals[[dn]][[nm]][[.EVAL_ERR_KEY()]] <- c(
        env$model$record_evals[[dn]][[nm]][[.EVAL_ERR_KEY()]]
        , err
      )
    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.record.evaluation"
  callback
}

266
# Build a post-iteration callback implementing early stopping.
#
# Every metric is treated internally as "bigger is better": the score of any
# metric whose `higher_better` flag is not TRUE is negated. For each checked
# metric the best score and the iteration that produced it are tracked.
# Training is flagged to stop (env$met_early_stop <- TRUE) when
#   * a checked metric has not improved for `stopping_rounds` iterations, or
#   * the last iteration (env$end_iteration) is reached.
# The first metric (in env$eval_list order) that triggers either condition
# decides the outcome. Its best iteration is written to env$best_iter and,
# when env$model is not NULL, to env$model$best_iter. Its best score is
# written to env$model$best_score, on the internal bigger-is-better scale,
# i.e. negated for metrics where lower is better. Later metrics are not
# examined in that call, so they cannot overwrite the result.
#
# stopping_rounds:   number of consecutive non-improving iterations tolerated.
# first_metric_only: if TRUE, only the first entry of env$eval_list is checked.
# verbose:           if TRUE, print status messages (built with
#                    merge.eval.string()).
#
# Returns the callback, tagged with the attributes "call" and "name".
cb.early.stop <- function(stopping_rounds, first_metric_only = FALSE, verbose = TRUE) {

  factor_to_bigger_better <- NULL
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  # One-time setup, run on the first invocation of the callback
  init <- function(env) {

    # Early stopping cannot work without metrics
    if (length(env$eval_list) == 0L) {
      stop("For early stopping, valids must have at least one element")
    }

    eval_len <<- length(env$eval_list)

    if (isTRUE(verbose)) {
      msg <- paste0(
        "Will train until there is no improvement in "
        , stopping_rounds
        , " rounds."
      )
      print(msg)
    }

    # Internally treat everything as a maximization task
    higher_better <- vapply(
      env$eval_list
      , function(e) isTRUE(e$higher_better)
      , logical(1L)
    )
    factor_to_bigger_better <<- ifelse(higher_better, 1.0, -1.0)
    best_iter <<- rep.int(-1L, eval_len)
    best_score <<- rep.int(-Inf, eval_len)
    best_msg <<- as.list(rep.int("", eval_len))

  }

  # Record metric i's best iteration/score as the final result and flag the
  # training loop to stop; `prefix` starts the verbose status message.
  finish <- function(env, i, prefix) {

    if (!is.null(env$model)) {
      env$model$best_score <- best_score[i]
      env$model$best_iter <- best_iter[i]
    }

    if (isTRUE(verbose)) {
      print(paste0(prefix, best_msg[[i]]))
    }

    env$best_iter <- best_iter[i]
    env$met_early_stop <- TRUE

  }

  # `finalize` is accepted for interface compatibility and is not used
  callback <- function(env, finalize = FALSE) {

    if (is.null(eval_len)) {
      init(env = env)
    }

    cur_iter <- env$iteration

    # By default, any metric can trigger early stopping. This can be disabled
    # with 'first_metric_only = TRUE'
    if (isTRUE(first_metric_only)) {
      evals_to_check <- 1L
    } else {
      evals_to_check <- seq_len(eval_len)
    }

    for (i in evals_to_check) {

      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]

      if (score > best_score[i]) {

        best_score[i] <<- score
        best_iter[i] <<- cur_iter

        # Keep the summary line of the best iteration for later messages
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(merge.eval.string(env = env))
        }

      } else if (cur_iter - best_iter[i] >= stopping_rounds) {

        finish(env, i, "Early stopping, best iteration is: ")
        # The first triggering metric decides; later ones must not overwrite
        # the stored result or print a second message
        break

      }

      if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {
        finish(env, i, "Did not meet early stopping, best iteration is: ")
        break
      }

    }
  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.early.stop"
  callback
}

# Collect the "name" attribute of every callback in `cb_list`. Callbacks
# without a "name" attribute contribute nothing; element names of `cb_list`
# carry over to the result. Returns NULL for an empty list.
callback.names <- function(cb_list) {
  cb_names <- lapply(cb_list, function(cb) attr(cb, "name"))
  unlist(cb_names)
}
Guolin Ke's avatar
Guolin Ke committed
408
409

# Append callback `cb` to `cb_list` and name every element after its "name"
# attribute. If an element named "cb.early.stop" is present, the first such
# element is moved to the end of the list.
add.cb <- function(cb_list, cb) {

  merged <- c(cb_list, cb)
  names(merged) <- callback.names(cb_list = merged)

  if (!("cb.early.stop" %in% names(merged))) {
    return(merged)
  }

  # Append a copy of the first "cb.early.stop" element, then drop the
  # original: `[<-` with a name only matches its first occurrence.
  merged <- c(merged, merged["cb.early.stop"])
  merged["cb.early.stop"] <- NULL
  merged
}

# Split callbacks into those whose "is_pre_iteration" attribute is TRUE
# (pre_iter) and all others, i.e. attribute absent or FALSE (post_iter).
# Element names and order are preserved within each group.
# Returns list(pre_iter = ..., post_iter = ...).
categorize.callbacks <- function(cb_list) {

  runs_before_iteration <- function(cb) {
    flag <- attr(cb, "is_pre_iteration")
    !is.null(flag) && flag
  }

  list(
    pre_iter = Filter(runs_before_iteration, cb_list),
    post_iter = Filter(Negate(runs_before_iteration), cb_list)
  )
}