# Shared state object for training callbacks.
#
# The field names match what the callbacks in this file read and write on
# their `env` argument:
#   model            object whose reset_parameter(), record_evals, best_iter
#                    and best_score members the callbacks use
#   iteration        index of the boosting round currently being processed
#   begin_iteration  index of the first boosting round
#   end_iteration    index of the last boosting round
#   eval_list        list of evaluation results; callbacks read the elements'
#                    data_name, name, value and higher_better entries
#   eval_err_list    optional list of evaluation errors, parallel to eval_list
#   best_iter        best round found so far (-1L until one is recorded)
#   best_score       score of the best round (NA until one is recorded)
#   met_early_stop   set to TRUE by the early-stopping callback
#
# cloneable = FALSE: the class does not generate a clone() method.
#' @importFrom R6 R6Class
CB_ENV <- R6::R6Class(
  "lgb.cb_env",
  cloneable = FALSE,
  public = list(
    model = NULL,
    iteration = NULL,
    begin_iteration = NULL,
    end_iteration = NULL,
    eval_list = list(),
    eval_err_list = list(),
    best_iter = -1L,
    best_score = NA,
    met_early_stop = FALSE
  )
)

# Create a pre-iteration callback that resets model parameters before every
# boosting round.
#
# new_params: named list of parameter schedules. Each element is either
#   * a function of exactly two arguments, called as f(i, nrounds), where `i`
#     is the zero-based offset of the current round from env$begin_iteration
#     and `nrounds` is the total number of rounds, or
#   * a numeric or character vector holding one value per boosting round.
#   Dots in parameter names are converted to underscores.
#
# Returns a function(env) carrying the attributes "call", "is_pre_iteration"
# (TRUE) and "name" ("cb.reset.parameters").
#
# Errors (raised when the callback first runs) if env has no model, if a
# num_class / metric / boosting alias is scheduled, or if an element of
# new_params is neither a two-argument function nor a vector of length nrounds.
cb.reset.parameters <- function(new_params) {

  # Check for parameter list
  if (!is.list(new_params)) {
    stop(sQuote("new_params"), " must be a list")
  }

  # Normalise the parameter names and rename the list itself, so that the
  # lookups by normalised name below actually find their element and the
  # normalised names are what reaches reset_parameter()
  pnames <- gsub("\\.", "_", names(new_params))
  if (!is.null(names(new_params))) {
    names(new_params) <- pnames
  }
  nrounds <- NULL

  # One-time validation, run lazily on the first callback invocation
  init <- function(env) {

    # Check for model environment
    if (is.null(env$model)) {
      stop("Env should have a ", sQuote("model"))
    }

    # Some parameters are not allowed to be changed during boosting
    aliases <- .PARAMETER_ALIASES()
    not_allowed <- c(
      aliases[["num_class"]]
      , aliases[["metric"]]
      , aliases[["boosting"]]
    )
    forbidden <- pnames[pnames %in% not_allowed]
    if (length(forbidden) > 0L) {
      stop(
        "Parameters "
        , paste0(forbidden, collapse = ", ")
        , " cannot be changed during boosting"
      )
    }

    # Store boosting rounds
    nrounds <<- env$end_iteration - env$begin_iteration + 1L

    # Check every schedule
    for (n in pnames) {

      p <- new_params[[n]]

      if (is.function(p)) {

        # A schedule function must take exactly two arguments
        if (length(formals(p)) != 2L) {
          stop("Parameter ", sQuote(n), " is a function but not of two arguments")
        }

      } else if (is.numeric(p) || is.character(p)) {

        # A schedule vector needs one value per round
        if (length(p) != nrounds) {
          stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
        }

      } else {

        stop("Parameter ", sQuote(n), " is not a function or a vector")

      }

    }

  }

  callback <- function(env) {

    # nrounds is only set by init(), so NULL means "not initialised yet"
    if (is.null(nrounds)) {
      init(env)
    }

    # Zero-based offset of the current round
    i <- env$iteration - env$begin_iteration

    # Resolve every schedule to its value for this round
    pars <- lapply(new_params, function(p) {
      if (is.function(p)) {
        return(p(i, nrounds))
      }
      # R vectors are one-based: offset 0 is the first element
      p[i + 1L]
    })

    if (!is.null(env$model)) {
      env$model$reset_parameter(pars)
    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "is_pre_iteration") <- TRUE
  attr(callback, "name") <- "cb.reset.parameters"
  callback
}

# Format one evaluation result as "<data_name>'s <name>:<value>", with
# "+<eval_err>" appended when an evaluation error is supplied.
#
# eval_res: list with entries data_name, name and value.
# eval_err: optional numeric error term; NULL (the default) omits it.
#
# Returns the formatted character string (numbers are rendered with "%g").
# Errors with "no evaluation results" when eval_res is NULL or empty.
format.eval.string <- function(eval_res, eval_err = NULL) {

  # length(NULL) is 0, so this also rejects a NULL result
  if (length(eval_res) == 0L) {
    stop("no evaluation results")
  }

  if (is.null(eval_err)) {
    return(sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value))
  }

  sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)

}

# Build the one-line evaluation message for the current iteration.
#
# env: object providing iteration, eval_list and eval_err_list.
#
# Returns "" when env$eval_list is empty. Otherwise returns "[<iteration>]:"
# followed by one format.eval.string() entry per evaluation result, all
# separated by tabs.
merge.eval.string <- function(env) {

  # Nothing to report without evaluation results
  if (length(env$eval_list) <= 0L) {
    return("")
  }

  header <- sprintf("[%d]:", env$iteration)

  # Errors are only attached when an error list was supplied
  is_eval_err <- length(env$eval_err_list) > 0L

  # Format every result in one pass instead of growing a list inside a loop
  entries <- lapply(seq_along(env$eval_list), function(j) {
    eval_err <- NULL
    if (is_eval_err) {
      eval_err <- env$eval_err_list[[j]]
    }
    format.eval.string(env$eval_list[[j]], eval_err)
  })

  # Return tab-separated message
  paste0(c(header, unlist(entries)), collapse = "\t")

}

# Create a callback that prints the evaluation message every `period` rounds.
#
# period: printing interval. A value of 0 or less disables printing.
#
# The returned function(env) prints when (iteration - 1) is a multiple of
# `period`, and also on env$begin_iteration and env$end_iteration. Nothing is
# printed when the message is empty. The callback carries the attributes
# "call" and "name" ("cb.print.evaluation").
cb.print.evaluation <- function(period = 1L) {

  callback <- function(env) {

    # Printing is disabled for non-positive periods
    if (period <= 0L) {
      return(invisible(NULL))
    }

    i <- env$iteration

    # Print on every period-th round and always on the first and last round
    if ((i - 1L) %% period == 0L || is.element(i, c(env$begin_iteration, env$end_iteration))) {

      # Build the message once and reuse it for both the check and the output
      msg <- merge.eval.string(env)

      if (nchar(msg) > 0L) {
        cat(msg, "\n")
      }

    }

  }

  # Store attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.print.evaluation"

  callback

}

# Create a callback that appends each round's evaluation results to
# env$model$record_evals.
#
# On the first call with results (record_evals still empty) the callback
# stores env$begin_iteration as record_evals$start_iter and creates, for every
# data_name / metric name pair, a list with empty `eval` and `eval_err` lists.
# Every call then appends the current value (and the evaluation error, when
# env$eval_err_list is non-empty) to those lists.
#
# The returned function(env) does nothing when env$eval_list is empty. It
# carries the attributes "call" and "name" ("cb.record.evaluation").
cb.record.evaluation <- function() {

  callback <- function(env) {

    # Nothing to record without evaluation results
    if (length(env$eval_list) <= 0L) {
      return()
    }

    # Errors are only recorded when an error list was supplied
    is_eval_err <- length(env$eval_err_list) > 0L

    # First recorded round: set up the storage structure
    if (length(env$model$record_evals) == 0L) {

      # The starting iteration does not depend on the metric, so set it once
      env$model$record_evals$start_iter <- env$begin_iteration

      for (eval_res in env$eval_list) {

        data_name <- eval_res$data_name
        name <- eval_res$name

        if (is.null(env$model$record_evals[[data_name]])) {
          env$model$record_evals[[data_name]] <- list()
        }

        # Empty placeholders that the appends below extend
        env$model$record_evals[[data_name]][[name]] <- list(
          eval = list()
          , eval_err = list()
        )

      }

    }

    # Append this round's results
    for (j in seq_along(env$eval_list)) {

      eval_res <- env$eval_list[[j]]
      eval_err <- NULL
      if (is_eval_err) {
        eval_err <- env$eval_err_list[[j]]
      }

      data_name <- eval_res$data_name
      name <- eval_res$name

      env$model$record_evals[[data_name]][[name]]$eval <- c(
        env$model$record_evals[[data_name]][[name]]$eval
        , eval_res$value
      )
      # Appending NULL leaves the list unchanged when there is no error term
      env$model$record_evals[[data_name]][[name]]$eval_err <- c(
        env$model$record_evals[[data_name]][[name]]$eval_err
        , eval_err
      )

    }

  }

  # Store attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.record.evaluation"

  callback

}

# Create a callback that stops training when no metric has improved for
# `stopping_rounds` consecutive rounds.
#
# stopping_rounds: number of rounds without improvement after which
#   env$met_early_stop is set.
# verbose: when TRUE, prints the stopping rule up front and the best
#   iteration's evaluation message at the end. Any other value is silent.
#
# The returned function(env, finalize = FALSE) tracks, per metric in
# env$eval_list, the best score and the round it occurred in. When a metric
# has not improved for `stopping_rounds` rounds, or when env$end_iteration is
# reached first, it writes the best round to env$best_iter, writes the best
# round and score to env$model (if present) and sets env$met_early_stop.
# The `finalize` argument is accepted but not used. The callback carries the
# attributes "call" and "name" ("cb.early.stop").
#
# Errors on the first call if env$eval_list is empty.
cb.early.stop <- function(stopping_rounds, verbose = TRUE) {

  # Per-metric state shared across calls; filled in by init() on first use
  factor_to_bigger_better <- NULL
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  # Initialization function
  init <- function(env) {

    # Early stopping cannot work without metrics
    if (length(env$eval_list) == 0L) {
      stop("For early stopping, valids must have at least one element")
    }

    # Store evaluation length
    eval_len <<- length(env$eval_list)

    if (isTRUE(verbose)) {
      cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
    }

    # Scores are tracked as "bigger is better": lower-is-better metrics are
    # multiplied by -1 so that one comparison works for every metric
    factor_to_bigger_better <<- rep.int(1.0, eval_len)
    best_iter <<- rep.int(-1L, eval_len)
    best_score <<- rep.int(-Inf, eval_len)
    best_msg <<- rep(list(""), eval_len)

    for (i in seq_len(eval_len)) {
      if (!env$eval_list[[i]]$higher_better) {
        factor_to_bigger_better[i] <<- -1.0
      }
    }

  }

  # Publish the best round of metric `i` and flag that training should stop
  finish <- function(env, i, header) {

    if (!is.null(env$model)) {
      # Multiplying by the same +/-1 factor again restores the metric's own
      # sign, so lower-is-better metrics are not reported negated
      env$model$best_score <- best_score[i] * factor_to_bigger_better[i]
      env$model$best_iter <- best_iter[i]
    }

    if (isTRUE(verbose)) {
      cat(header, "\n")
      cat(best_msg[[i]], "\n")
    }

    # Store best iteration and stop
    env$best_iter <- best_iter[i]
    env$met_early_stop <- TRUE

  }

  # Create callback
  callback <- function(env, finalize = FALSE) {

    # eval_len is only set by init(), so NULL means "not initialised yet"
    if (is.null(eval_len)) {
      init(env)
    }

    cur_iter <- env$iteration

    for (i in seq_len(eval_len)) {

      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]

      if (score > best_score[i]) {

        # New best round for this metric
        best_score[i] <<- score
        best_iter[i] <<- cur_iter

        # isTRUE() keeps this consistent with the other verbose checks and
        # avoids an error when verbose is NULL or NA
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(merge.eval.string(env))
        }

      } else if (cur_iter - best_iter[i] >= stopping_rounds) {

        finish(env, i, "Early stopping, best iteration is:")

      }

      # Last round reached without triggering early stopping
      if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {
        finish(env, i, "Did not meet early stopping, best iteration is:")
      }

    }

  }

  # Set attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.early.stop"

  callback

}

# Collect the "name" attribute of every callback in cb_list.
#
# cb_list: list of callback functions.
#
# Returns the attribute values flattened with unlist(): a character vector
# when the attributes are strings, or NULL when cb_list is empty or no
# callback has the attribute. Callbacks without a "name" attribute contribute
# no entry.
callback.names <- function(cb_list) {
  name_attrs <- lapply(cb_list, function(cb) attr(cb, "name"))
  unlist(name_attrs)
}
# Append a callback to a callback list, keeping early stopping last.
#
# cb_list: list of callback functions (may be empty).
# cb: callback function to append.
#
# Returns the combined list, named by each callback's "name" attribute. If a
# callback named "cb.early.stop" is present, the first such entry is moved to
# the end of the list.
add.cb <- function(cb_list, cb) {

  # Combine and (re)name the elements
  cb_list <- c(cb_list, cb)
  names(cb_list) <- callback.names(cb_list)

  # Move the first early-stopping callback to the end of the list
  early_stop_pos <- match("cb.early.stop", names(cb_list))
  if (!is.na(early_stop_pos)) {
    cb_list <- c(cb_list[-early_stop_pos], cb_list[early_stop_pos])
  }

  cb_list

}

# Split a callback list by its "is_pre_iteration" attribute.
#
# cb_list: list of callback functions.
#
# Returns a list with two elements:
#   pre_iter   callbacks whose "is_pre_iteration" attribute is TRUE
#   post_iter  all other callbacks (attribute missing or FALSE)
# The order and names of cb_list are preserved within each group.
categorize.callbacks <- function(cb_list) {

  # A callback without the attribute is treated as post-iteration
  is_pre_iteration <- function(cb) {
    pre <- attr(cb, "is_pre_iteration")
    !is.null(pre) && pre
  }

  list(
    pre_iter = Filter(is_pre_iteration, cb_list)
    , post_iter = Filter(function(cb) !is_pre_iteration(cb), cb_list)
  )

}