# callback.R
# Constants that control naming in lists: the keys under which
# cb.record.evaluation() stores metric values and their error terms in
# record_evals[[data_name]][[metric_name]]
.EVAL_KEY <- function() "eval"
.EVAL_ERR_KEY <- function() "eval_err"

# State object handed to every training callback. The callbacks in this file
# read model, iteration, begin_iteration, end_iteration, eval_list and
# eval_err_list, and cb.early.stop() writes best_iter and met_early_stop.
#' @importFrom R6 R6Class
CB_ENV <- R6::R6Class(
  "lgb.cb_env",
  cloneable = FALSE,
  public = list(
    # Booster being trained (callbacks call model$reset_parameter() and write
    # model$record_evals, model$best_score and model$best_iter)
    model = NULL,
    # Current round, and the first / last round of this training run
    iteration = NULL,
    begin_iteration = NULL,
    end_iteration = NULL,
    # Evaluation results of the current round; each element is read as a list
    # with data_name, name, value and higher_better
    eval_list = list(),
    # Error terms matching eval_list element-by-element, or empty when absent
    eval_err_list = list(),
    # Best round found by early stopping (-1L until set)
    best_iter = -1L,
    # NOTE(review): not written by the callbacks in this file
    best_score = NA,
    # Set to TRUE by cb.early.stop() when training should halt
    met_early_stop = FALSE
  )
)

# Build a pre-iteration callback that changes booster parameters during
# training. Each element of `new_params` is either
#   * a vector with exactly one value per boosting round of this run
#     (element k is used in the k-th round), or
#   * a function of two arguments, called as f(k, nrounds) for the k-th round.
# Dots in parameter names are replaced by underscores ("learning.rate" becomes
# "learning_rate"); the normalised names are used both for validation and in
# the list passed to env$model$reset_parameter().
#
# @param new_params named list of per-round values or functions
# @return the callback function, tagged with "call", "is_pre_iteration" and
#   "name" attributes
cb.reset.parameters <- function(new_params) {

  # Check for parameter list
  if (!identical(class(new_params), "list")) {
    stop(sQuote("new_params"), " must be a list")
  }

  # Normalise names once, so that validation, value lookup and the reset call
  # all agree (previously the normalised names were used to index the
  # un-normalised list, so dotted names always failed validation)
  if (!is.null(names(new_params))) {
    names(new_params) <- gsub("\\.", "_", names(new_params))
  }
  pnames <- names(new_params)
  nrounds <- NULL

  # Validate the parameters on first use, once the round range is known
  init <- function(env) {

    # Check for model environment
    if (is.null(env$model)) {
      stop("Env should have a ", sQuote("model"))
    }

    # Store boosting rounds
    nrounds <<- env$end_iteration - env$begin_iteration + 1L

    for (n in pnames) {
      p <- new_params[[n]]

      if (is.function(p)) {
        if (length(formals(p)) != 2L) {
          stop("Parameter ", sQuote(n), " is a function but not of two arguments")
        }
      } else if (is.numeric(p) || is.character(p)) {
        if (length(p) != nrounds) {
          stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
        }
      } else {
        stop("Parameter ", sQuote(n), " is not a function or a vector")
      }
    }
  }

  callback <- function(env) {

    if (is.null(nrounds)) {
      init(env)
    }

    # 1-based position of the current round within this run. Vectors must
    # have length nrounds, so rounds map to p[1], ..., p[nrounds]
    # (a 0-based index would read the empty p[0] on the first round)
    i <- env$iteration - env$begin_iteration + 1L

    pars <- lapply(new_params, function(p) {
      if (is.function(p)) {
        return(p(i, nrounds))
      }
      p[i]
    })

    if (!is.null(env$model)) {
      env$model$reset_parameter(pars)
    }
  }

  attr(callback, "call") <- match.call()
  attr(callback, "is_pre_iteration") <- TRUE
  attr(callback, "name") <- "cb.reset.parameters"
  callback
}

# Format one evaluation result as "<data_name>'s <name>:<value>", appending
# "+<eval_err>" when an error term is supplied.
#
# @param eval_res list with data_name, name and value
# @param eval_err optional error term for the value
# @return a character string
format.eval.string <- function(eval_res, eval_err = NULL) {

  # Refuse to format a missing or empty result
  if (is.null(eval_res) || length(eval_res) == 0L) {
    stop("no evaluation results")
  }

  if (!is.null(eval_err)) {
    sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)
  } else {
    sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value)
  }
}

# Join all evaluation results of the current round into one tab-separated
# line: "[<iteration>]:" followed by one format.eval.string() piece per
# element of env$eval_list (paired with env$eval_err_list when it is
# non-empty).
#
# @param env callback environment
# @return the merged string, or "" when there are no evaluation results
merge.eval.string <- function(env) {

  n_evals <- length(env$eval_list)
  if (n_evals <= 0L) {
    return("")
  }

  # Error terms are used only when the error list is non-empty; it is then
  # indexed in parallel with eval_list
  has_err <- length(env$eval_err_list) > 0L

  pieces <- lapply(seq_len(n_evals), function(j) {
    eval_err <- if (has_err) env$eval_err_list[[j]] else NULL
    format.eval.string(env$eval_list[[j]], eval_err)
  })

  paste0(c(sprintf("[%d]:", env$iteration), unlist(pieces)), collapse = "\t")
}

# Build a callback that prints the merged evaluation string every `period`
# rounds (rounds 1, 1 + period, ...), and always on the first and last rounds
# of the run. A non-positive `period` disables printing.
#
# @param period print interval in rounds
# @return the callback function, tagged with "call" and "name" attributes
cb.print.evaluation <- function(period = 1L) {

  callback <- function(env) {

    # Printing disabled
    if (period <= 0L) {
      return(invisible(NULL))
    }

    i <- env$iteration
    is_due <- (i - 1L) %% period == 0L ||
      is.element(i, c(env$begin_iteration, env$end_iteration))
    if (!is_due) {
      return(invisible(NULL))
    }

    # Build the message once and reuse it (it used to be rebuilt for printing)
    msg <- merge.eval.string(env)
    if (nchar(msg) > 0L) {
      cat(msg, "\n")
    }
  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.print.evaluation"

  callback
}

# Build a callback that appends the current round's evaluation results to
# env$model$record_evals[[data_name]][[metric_name]], storing values under
# .EVAL_KEY() and error terms (when env$eval_err_list is non-empty) under
# .EVAL_ERR_KEY().
#
# @return the callback function, tagged with "call" and "name" attributes
cb.record.evaluation <- function() {

  callback <- function(env) {

    evals <- env$eval_list
    if (length(evals) <= 0L) {
      return()
    }

    has_err <- length(env$eval_err_list) > 0L

    # Nothing recorded yet: lay out the nested record skeleton once. The start
    # iteration does not depend on the metric, so it is set outside the loop
    if (length(env$model$record_evals) == 0L) {
      env$model$record_evals$start_iter <- env$begin_iteration
      for (entry in evals) {
        ds <- entry$data_name
        metric <- entry$name
        if (is.null(env$model$record_evals[[ds]])) {
          env$model$record_evals[[ds]] <- list()
        }
        env$model$record_evals[[ds]][[metric]] <- list()
        env$model$record_evals[[ds]][[metric]][[.EVAL_KEY()]] <- list()
        env$model$record_evals[[ds]][[metric]][[.EVAL_ERR_KEY()]] <- list()
      }
    }

    # Append this round's value (and error term, if any) for every metric
    for (k in seq_along(evals)) {
      res <- evals[[k]]
      err <- if (has_err) env$eval_err_list[[k]] else NULL

      slot <- env$model$record_evals[[res$data_name]][[res$name]]
      slot[[.EVAL_KEY()]] <- c(slot[[.EVAL_KEY()]], res$value)
      slot[[.EVAL_ERR_KEY()]] <- c(slot[[.EVAL_ERR_KEY()]], err)
      env$model$record_evals[[res$data_name]][[res$name]] <- slot
    }
  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.record.evaluation"

  callback
}

# Build a callback that ends training once an evaluation metric has not
# improved for `stopping_rounds` rounds, or when the last round is reached.
#
# Internally every metric is treated as "bigger is better": metrics whose
# `higher_better` flag is not TRUE are multiplied by -1 before comparison.
# Metrics are checked in order, and the first one that ends training decides
# the reported result. The callback then sets env$best_iter and
# env$met_early_stop. When env$model is present, it also sets
# model$best_iter and model$best_score, the latter in the metric's own scale.
#
# @param stopping_rounds number of rounds without improvement before stopping
# @param verbose if TRUE, print progress messages with cat()
# @return the callback function, tagged with "call" and "name" attributes
cb.early.stop <- function(stopping_rounds, verbose = TRUE) {

  # Per-metric state shared through the closure, created by init()
  factor_to_bigger_better <- NULL
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  # Set up per-metric state from the first round's evaluation list
  init <- function(env) {

    # Early stopping cannot work without metrics
    if (length(env$eval_list) == 0L) {
      stop("For early stopping, valids must have at least one element")
    }

    eval_len <<- length(env$eval_list)

    if (isTRUE(verbose)) {
      cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
    }

    factor_to_bigger_better <<- rep.int(1.0, eval_len)
    best_iter <<- rep.int(-1L, eval_len)
    best_score <<- rep.int(-Inf, eval_len)
    best_msg <<- as.list(rep.int("", eval_len))

    for (i in seq_len(eval_len)) {
      if (!isTRUE(env$eval_list[[i]]$higher_better)) {
        factor_to_bigger_better[i] <<- -1.0
      }
    }
  }

  # Publish metric i's best round, optionally print `header` plus the
  # evaluation string of that round, and flag training to stop
  record_best <- function(env, i, header) {
    if (!is.null(env$model)) {
      # Undo the internal sign flip so lower-is-better metrics are not
      # reported as negative numbers
      env$model$best_score <- best_score[i] * factor_to_bigger_better[i]
      env$model$best_iter <- best_iter[i]
    }
    if (isTRUE(verbose)) {
      cat(header, "\n")
      cat(best_msg[[i]], "\n")
    }
    env$best_iter <- best_iter[i]
    env$met_early_stop <- TRUE
  }

  # `finalize` is accepted for interface compatibility and is not used
  callback <- function(env, finalize = FALSE) {

    if (is.null(eval_len)) {
      init(env)
    }

    cur_iter <- env$iteration

    for (i in seq_len(eval_len)) {

      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]

      if (score > best_score[i]) {
        best_score[i] <<- score
        best_iter[i] <<- cur_iter
        # The message is only ever printed when verbose
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(merge.eval.string(env))
        }
      } else if (cur_iter - best_iter[i] >= stopping_rounds) {
        record_best(env, i, "Early stopping, best iteration is:")
        # Stop at the first metric that ends training, so later metrics
        # cannot overwrite the reported result or print a second message
        break
      }

      if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {
        record_best(env, i, "Did not meet early stopping, best iteration is:")
        break
      }
    }
  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.early.stop"

  callback
}

# Extract callback names from the list of callbacks: the "name" attribute of
# each element, flattened into a character vector (elements without the
# attribute are dropped; list names are carried over by lapply/unlist).
callback.names <- function(cb_list) {
  cb_names <- lapply(cb_list, function(cb) attr(cb, "name"))
  unlist(cb_names)
}

# Append callback `cb` to `cb_list`, name every element after its "name"
# attribute, and move the first "cb.early.stop" callback (if any) to the end
# of the list so that it runs after the other callbacks.
#
# @param cb_list existing list of callbacks
# @param cb callback to add
# @return the combined, named list of callbacks
add.cb <- function(cb_list, cb) {

  combined <- c(cb_list, cb)
  names(combined) <- callback.names(combined)

  stop_pos <- match("cb.early.stop", names(combined))
  if (is.na(stop_pos)) {
    return(combined)
  }

  # Keep every other callback in its original order, early stopping last
  c(combined[-stop_pos], combined[stop_pos])
}

# Split callbacks into those that run before each boosting round (attribute
# "is_pre_iteration" set to TRUE) and those that run after it (attribute
# absent or FALSE).
#
# @param cb_list list of callbacks
# @return list(pre_iter = <callbacks>, post_iter = <callbacks>)
categorize.callbacks <- function(cb_list) {

  runs_before <- function(cb) {
    flag <- attr(cb, "is_pre_iteration")
    !is.null(flag) && flag
  }

  runs_after <- function(cb) {
    flag <- attr(cb, "is_pre_iteration")
    is.null(flag) || !flag
  }

  list(
    pre_iter = Filter(runs_before, cb_list),
    post_iter = Filter(runs_after, cb_list)
  )
}