callback.R 10.4 KB
Newer Older
1
2
3
4
5
6
7
8
# constants that control naming in lists
# List key under which cb.record.evaluation() stores metric values.
.EVAL_KEY <- function() {
  key <- "eval"
  return(key)
}
# List key under which cb.record.evaluation() stores evaluation errors.
.EVAL_ERR_KEY <- function() {
  key <- "eval_err"
  return(key)
}

James Lamb's avatar
James Lamb committed
9
10
#' @importFrom R6 R6Class
# Mutable state object handed to every callback in this file as `env`.
# It is an R6 class (reference semantics), so a field written by one callback
# is visible to whoever reads the same object afterwards.
CB_ENV <- R6::R6Class(
  "lgb.cb_env",
  cloneable = FALSE,
  public = list(

    # Model being trained. Callbacks call model$reset_parameter() and write
    # model$record_evals, model$best_iter and model$best_score.
    model = NULL,

    # Current iteration, plus the first and last iteration of the run
    iteration = NULL,
    begin_iteration = NULL,
    end_iteration = NULL,

    # Evaluation results of the current iteration. Callbacks read the fields
    # data_name, name, value and higher_better of each element.
    eval_list = list(),

    # Evaluation errors, indexed in parallel with eval_list; may stay empty
    eval_err_list = list(),

    # Best iteration; written by cb.early.stop()
    best_iter = -1L,

    # Not read or written by the callbacks in this file
    best_score = NA,

    # Set to TRUE by cb.early.stop() once it has picked the best iteration
    met_early_stop = FALSE
  )
)

# Build a pre-iteration callback that resets model parameters on every round.
#
# Arguments:
#   new_params  a plain list (class exactly "list"). Each element is either
#                 * a function of exactly two arguments, called as
#                   p(i, nrounds), where i is the zero-based offset of the
#                   current iteration from env$begin_iteration, or
#                 * a numeric / character vector with one value per round.
#               Dots in element names are replaced by underscores.
#
# Returns a function(env) carrying the attributes "call", "name" and
# "is_pre_iteration" (TRUE). When env$model is not NULL the callback returns
# whatever env$model$reset_parameter(params = ...) returns.
cb.reset.parameters <- function(new_params) {

  if (!identical(class(new_params), "list")) {
    stop(sQuote("new_params"), " must be a list")
  }

  # Normalize the names on the list itself. The validation loop below looks
  # elements up by name, so looking up normalized names in a list that still
  # carried dotted names would find nothing.
  if (!is.null(names(new_params))) {
    names(new_params) <- gsub("\\.", "_", names(new_params))
  }
  pnames <- names(new_params)

  # Number of boosting rounds; filled in by init() on the first call
  nrounds <- NULL

  # One-time validation, run on the first invocation of the callback
  init <- function(env) {

    if (is.null(env$model)) {
      stop("Env should have a ", sQuote("model"))
    }

    nrounds <<- env$end_iteration - env$begin_iteration + 1L

    for (n in pnames) {

      p <- new_params[[n]]

      if (is.function(p)) {

        if (length(formals(p)) != 2L) {
          stop("Parameter ", sQuote(n), " is a function but not of two arguments")
        }

      } else if (is.numeric(p) || is.character(p)) {

        if (length(p) != nrounds) {
          stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
        }

      } else {

        stop("Parameter ", sQuote(n), " is not a function or a vector")

      }

    }

    return(invisible(NULL))

  }

  callback <- function(env) {

    if (is.null(nrounds)) {
      init(env = env)
    }

    # Zero-based offset of the current iteration within this training run
    i <- env$iteration - env$begin_iteration

    pars <- lapply(new_params, function(p) {
      if (is.function(p)) {
        return(p(i, nrounds))
      }
      # Vectors hold one value per round and R indexes from 1, so the value
      # for offset i lives at position i + 1.
      return(p[i + 1L])
    })

    if (!is.null(env$model)) {
      return(env$model$reset_parameter(params = pars))
    }

    return(invisible(NULL))

  }

  attr(callback, "call") <- match.call()
  attr(callback, "is_pre_iteration") <- TRUE
  attr(callback, "name") <- "cb.reset.parameters"
  return(callback)
}

# Format one evaluation result as "<data_name>'s <name>:<value>", with
# "+<eval_err>" appended when an evaluation error is supplied.
#
# Arguments:
#   eval_res  list with the fields data_name, name and value
#   eval_err  evaluation error to append, or NULL (the default) for none
#
# Raises an error when eval_res is NULL or has length zero.
format.eval.string <- function(eval_res, eval_err = NULL) {

  if (is.null(eval_res) || length(eval_res) == 0L) {
    stop("no evaluation results")
  }

  if (!is.null(eval_err)) {
    return(sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err))
  }

  return(sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value))
}

130
# Build the one-line evaluation message for the current iteration:
# "[<iteration>]:" followed by one formatted entry per element of
# env$eval_list, joined by two spaces.
#
# Returns "" when env$eval_list is empty. Entries of env$eval_err_list are
# passed along only when that list is non-empty; it is then indexed in
# parallel with env$eval_list.
merge.eval.string <- function(env) {

  if (length(env$eval_list) <= 0L) {
    return("")
  }

  header <- sprintf("[%d]:", env$iteration)
  is_eval_err <- length(env$eval_err_list) > 0L

  # Format every entry in one pass instead of growing a list inside a loop
  entries <- lapply(seq_along(env$eval_list), function(j) {
    eval_err <- NULL
    if (isTRUE(is_eval_err)) {
      eval_err <- env$eval_err_list[[j]]
    }
    return(format.eval.string(eval_res = env$eval_list[[j]], eval_err = eval_err))
  })

  return(paste0(c(header, unlist(entries)), collapse = "  "))
}

161
# Build a callback that prints the evaluation message every `period`
# iterations, and also on env$begin_iteration and env$end_iteration.
#
# Arguments:
#   period  printing interval; nothing is printed unless period > 0
#
# Returns a function(env) carrying the attributes "call" and "name".
# The callback itself always returns invisible(NULL).
cb.print.evaluation <- function(period) {

  callback <- function(env) {

    if (period > 0L) {

      i <- env$iteration

      # Print on iterations 1, 1 + period, 1 + 2 * period, ... and on the
      # first and last iteration of the run
      if ((i - 1L) %% period == 0L || is.element(i, c(env$begin_iteration, env$end_iteration))) {

        # Build the message once and reuse it for both the check and the print
        msg <- merge.eval.string(env = env)

        if (nchar(msg) > 0L) {
          print(msg)
        }

      }

    }

    return(invisible(NULL))

  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.print.evaluation"

  return(callback)
}

# Build a callback that appends the evaluation results of every iteration to
# env$model$record_evals.
#
# Layout written by the callback:
#   record_evals$start_iter                              env$begin_iteration
#   record_evals[[data_name]][[name]][[.EVAL_KEY()]]     metric values
#   record_evals[[data_name]][[name]][[.EVAL_ERR_KEY()]] evaluation errors
#
# Returns a function(env) carrying the attributes "call" and "name".
# The callback itself returns invisible(NULL).
cb.record.evaluation <- function() {

  callback <- function(env) {

    # Nothing to record without evaluation results
    if (length(env$eval_list) <= 0L) {
      return(invisible(NULL))
    }

    is_eval_err <- length(env$eval_err_list) > 0L

    # First call: create the empty record structure
    if (length(env$model$record_evals) == 0L) {

      # Identical for every metric, so it is set once rather than per element
      env$model$record_evals$start_iter <- env$begin_iteration

      for (j in seq_along(env$eval_list)) {

        data_name <- env$eval_list[[j]]$data_name
        name <- env$eval_list[[j]]$name

        if (is.null(env$model$record_evals[[data_name]])) {
          env$model$record_evals[[data_name]] <- list()
        }

        env$model$record_evals[[data_name]][[name]] <- list()
        env$model$record_evals[[data_name]][[name]][[.EVAL_KEY()]] <- list()
        env$model$record_evals[[data_name]][[name]][[.EVAL_ERR_KEY()]] <- list()

      }

    }

    # Append the results of the current iteration
    for (j in seq_along(env$eval_list)) {

      eval_res <- env$eval_list[[j]]
      eval_err <- NULL
      if (isTRUE(is_eval_err)) {
        eval_err <- env$eval_err_list[[j]]
      }

      data_name <- eval_res$data_name
      name <- eval_res$name

      env$model$record_evals[[data_name]][[name]][[.EVAL_KEY()]] <- c(
        env$model$record_evals[[data_name]][[name]][[.EVAL_KEY()]]
        , eval_res$value
      )
      # c(x, NULL) leaves x unchanged, so nothing is appended without an error
      env$model$record_evals[[data_name]][[name]][[.EVAL_ERR_KEY()]] <- c(
        env$model$record_evals[[data_name]][[name]][[.EVAL_ERR_KEY()]]
        , eval_err
      )

    }

    return(invisible(NULL))

  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.record.evaluation"

  return(callback)
}

274
# Build a callback that flags early stopping once a metric has not improved
# for `stopping_rounds` iterations.
#
# Arguments:
#   stopping_rounds    iterations without improvement after which
#                      env$met_early_stop is set to TRUE
#   first_metric_only  if TRUE only the first element of env$eval_list is
#                      checked; otherwise every metric can trigger the stop
#   verbose            messages are printed only when this is exactly TRUE
#
# Returns a function(env) carrying the attributes "call" and "name".
# The callback writes env$best_iter and env$met_early_stop and, when
# env$model is not NULL, env$model$best_iter and env$model$best_score.
cb.early.stop <- function(stopping_rounds, first_metric_only, verbose) {

  # Per-metric state shared between calls; filled by init() on the first call
  factor_to_bigger_better <- NULL
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  init <- function(env) {

    # Early stopping cannot work without metrics
    if (length(env$eval_list) == 0L) {
      stop("For early stopping, valids must have at least one element")
    }

    eval_len <<- length(env$eval_list)

    if (isTRUE(verbose)) {
      msg <- paste0(
        "Will train until there is no improvement in "
        , stopping_rounds
        , " rounds."
      )
      print(msg)
    }

    # Internally treat everything as a maximization task: scores of metrics
    # whose higher_better is not TRUE are multiplied by -1
    factor_to_bigger_better <<- rep.int(1.0, eval_len)
    best_iter <<- rep.int(-1L, eval_len)
    best_score <<- rep.int(-Inf, eval_len)
    best_msg <<- as.list(rep.int("", eval_len))

    for (i in seq_len(eval_len)) {
      if (!isTRUE(env$eval_list[[i]]$higher_better)) {
        factor_to_bigger_better[i] <<- -1.0
      }
    }

    return(invisible(NULL))

  }

  callback <- function(env) {

    if (is.null(eval_len)) {
      init(env = env)
    }

    cur_iter <- env$iteration

    # By default, any metric can trigger early stopping. This can be disabled
    # with 'first_metric_only = TRUE'
    if (isTRUE(first_metric_only)) {
      evals_to_check <- 1L
    } else {
      evals_to_check <- seq_len(eval_len)
    }

    for (i in evals_to_check) {

      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]

      if (score > best_score[i]) {

        best_score[i] <<- score
        best_iter[i] <<- cur_iter

        # isTRUE() matches the other verbose checks and does not error when
        # verbose is NULL or NA
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(merge.eval.string(env = env))
        }

      } else {

        if (cur_iter - best_iter[i] >= stopping_rounds) {

          # NOTE(review): best_score holds the sign-adjusted score (negated
          # for metrics whose higher_better is not TRUE) and that adjusted
          # value is what is stored on the model - confirm this is intended.
          if (!is.null(env$model)) {
            env$model$best_score <- best_score[i]
            env$model$best_iter <- best_iter[i]
          }

          if (isTRUE(verbose)) {
            print(paste0("Early stopping, best iteration is: ", best_msg[[i]]))
          }

          env$best_iter <- best_iter[i]
          env$met_early_stop <- TRUE
        }

      }

      # Last iteration reached without a stop: report the best iteration of
      # this metric and mark the run as finished
      if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {

        if (!is.null(env$model)) {
          env$model$best_score <- best_score[i]
          env$model$best_iter <- best_iter[i]
        }

        if (isTRUE(verbose)) {
          print(paste0("Did not meet early stopping, best iteration is: ", best_msg[[i]]))
        }

        env$best_iter <- best_iter[i]
        env$met_early_stop <- TRUE
      }

    }

    return(invisible(NULL))

  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.early.stop"

  return(callback)
}

# Collect the "name" attribute of every callback in cb_list into one vector.
# Callbacks without that attribute contribute nothing to the result.
callback.names <- function(cb_list) {
  name_attrs <- lapply(cb_list, function(cb) {
    attr(cb, "name")
  })
  return(unlist(name_attrs))
}
Guolin Ke's avatar
Guolin Ke committed
421
422

# Append callback `cb` to `cb_list`, name every element after its "name"
# attribute, and move the first "cb.early.stop" entry (if any) to the end
# of the list.
#
# Returns the updated, named list of callbacks.
add.cb <- function(cb_list, cb) {

  cb_list <- c(cb_list, cb)
  names(cb_list) <- callback.names(cb_list = cb_list)

  # match() yields the position of the first early-stopping callback, or NA
  early_stop_pos <- match("cb.early.stop", names(cb_list))
  if (!is.na(early_stop_pos)) {
    cb_list <- c(cb_list[-early_stop_pos], cb_list[early_stop_pos])
  }

  return(cb_list)
}

# Split a list of callbacks by their "is_pre_iteration" attribute.
#
# Returns a list with two elements:
#   pre_iter   callbacks whose attribute is set and TRUE
#   post_iter  callbacks whose attribute is missing or FALSE
categorize.callbacks <- function(cb_list) {

  # Single predicate shared by both filters
  is_pre_iteration <- function(cb) {
    pre <- attr(cb, "is_pre_iteration")
    return(!is.null(pre) && pre)
  }

  return(
    list(
      pre_iter = Filter(is_pre_iteration, cb_list)
      , post_iter = Filter(Negate(is_pre_iteration), cb_list)
    )
  )
}