# callback.R
#' @importFrom R6 R6Class
# Environment object handed to every training callback.
# It is an R6 class with cloning disabled; the callbacks in this file read and
# write its public fields directly.
CB_ENV <- R6::R6Class(
  classname = "lgb.cb_env"
  , cloneable = FALSE
  , public = list(

    # Model being trained; callbacks call env$model$reset_parameter() and
    # write record_evals / best_iter / best_score on it
    model = NULL

    # Current boosting iteration
    , iteration = NULL

    # First and last iteration of the current training run
    , begin_iteration = NULL
    , end_iteration = NULL

    # Evaluation results for the current iteration; each element is read as a
    # list with data_name, name, value and higher_better
    , eval_list = list()

    # Evaluation errors matching eval_list by position (may be empty)
    , eval_err_list = list()

    # Best iteration, written by the early-stopping callback
    , best_iter = -1

    # Best score placeholder
    , best_score = NA

    # Set to TRUE by the early-stopping callback when training should stop
    , met_early_stop = FALSE
  )
)

# Build a pre-iteration callback that resets model parameters on every
# boosting round.
#
# new_params: named list. Each element is either
#   - a function of exactly two arguments, called as p(i, nrounds) where i is
#     the zero-based offset of the current iteration from env$begin_iteration
#     and nrounds is the number of rounds in the run, or
#   - a numeric or character vector holding one value per boosting round
#     (its length must equal nrounds).
#   Dots in parameter names are treated as underscores for validation and in
#   error messages.
#
# Returns the callback, tagged with the attributes "call",
# "is_pre_iteration" (TRUE) and "name" ("cb.reset.parameters").
#
# Raises an error when new_params is not a list, when the environment has no
# model, when a forbidden parameter is supplied, or when an element is neither
# a two-argument function nor a vector of length nrounds.
cb.reset.parameters <- function(new_params) {

  # Check for parameter list
  if (!is.list(new_params)) {
    stop(sQuote("new_params"), " must be a list")
  }

  # Normalized parameter names (dots replaced by underscores)
  pnames <- gsub("\\.", "_", names(new_params))

  # Number of boosting rounds; stays NULL until init() has succeeded
  nrounds <- NULL

  # One-time validation, run on the first call of the callback
  init <- function(env) {

    # Check for model environment
    if (is.null(env$model)) {
      stop("Env should have a ", sQuote("model"))
    }

    # Number of rounds in this training run
    n_rounds <- env$end_iteration - env$begin_iteration + 1

    # Some parameters are not allowed to be changed,
    # since changing them would break the model
    not_allowed <- c("num_class", "metric", "boosting_type")
    if (any(pnames %in% not_allowed)) {
      stop(
        "Parameters "
        , paste0(pnames[pnames %in% not_allowed], collapse = ", ")
        , " cannot be changed during boosting"
      )
    }

    # Validate every parameter. Elements are fetched by position, because the
    # normalized name in pnames may differ from the original list name.
    for (idx in seq_along(pnames)) {

      n <- pnames[[idx]]
      p <- new_params[[idx]]

      if (is.function(p)) {

        # Functions must take exactly two arguments: (iteration, nrounds)
        if (length(formals(p)) != 2) {
          stop("Parameter ", sQuote(n), " is a function but not of two arguments")
        }

      } else if (is.numeric(p) || is.character(p)) {

        # Vectors must provide one value per boosting round
        if (length(p) != n_rounds) {
          stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
        }

      } else {

        stop("Parameter ", sQuote(n), " is not a function or a vector")

      }

    }

    # Only remember the round count once validation has passed, so that a
    # failed init() is retried instead of being silently skipped
    nrounds <<- n_rounds

  }

  callback <- function(env) {

    # Validate lazily on the first call
    if (is.null(nrounds)) {
      init(env)
    }

    # Zero-based offset of the current iteration within this run
    i <- env$iteration - env$begin_iteration

    # Resolve the value of every parameter for the current round
    pars <- lapply(new_params, function(p) {
      if (is.function(p)) {
        return(p(i, nrounds))
      }
      # R vectors are one-based, so round offset i maps to element i + 1
      p[i + 1]
    })

    # To-do check pars
    if (!is.null(env$model)) {
      env$model$reset_parameter(pars)
    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "is_pre_iteration") <- TRUE
  attr(callback, "name") <- "cb.reset.parameters"
  callback
}

# Format one evaluation result as "<data_name>'s <name>:<value>", with
# "+<eval_err>" appended when an evaluation error is given.
#
# eval_res: list providing data_name, name and value
# eval_err: optional evaluation error; NULL leaves it out of the string
#
# Raises an error when eval_res is NULL or has length zero.
format.eval.string <- function(eval_res, eval_err = NULL) {

  # length(NULL) is 0, so this also rejects a NULL result
  if (length(eval_res) == 0) {
    stop("no evaluation results")
  }

  # Without an error term only the metric value is reported
  if (is.null(eval_err)) {
    return(sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value))
  }

  sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)
}

129
# Build the one-line, tab-separated evaluation message for the current
# iteration: "[<iteration>]:" followed by one formatted entry per element of
# env$eval_list. Returns "" when there are no evaluation results.
merge.eval.string <- function(env) {

  n_evals <- length(env$eval_list)
  if (n_evals <= 0) {
    return("")
  }

  # Iteration prefix comes first in the message
  header <- sprintf("[%d]:", env$iteration)

  # Errors are only attached when an error list was provided
  has_err <- length(env$eval_err_list) > 0

  entries <- lapply(seq_len(n_evals), function(j) {
    err <- NULL
    if (has_err) {
      err <- env$eval_err_list[[j]]
    }
    format.eval.string(env$eval_list[[j]], err)
  })

  # Return tab-separated message
  paste0(c(header, unlist(entries)), collapse = "\t")
}

161
# Build a callback that prints the evaluation results of the current
# iteration.
#
# period: print every period-th iteration; the first and last iteration of the
#   run are always printed. A value that is not greater than 0 disables
#   printing.
#
# Returns the callback, tagged with the attributes "call" and "name"
# ("cb.print.evaluation").
cb.print.evaluation <- function(period = 1) {

  # Create callback
  callback <- function(env) {

    # Printing is only active for a positive period
    if (period > 0) {

      # Store iteration
      i <- env$iteration

      # Check if iteration matches modulo, or is the first/last iteration
      if ((i - 1) %% period == 0 || is.element(i, c(env$begin_iteration, env$end_iteration))) {

        # Merge evaluation string once and reuse it below
        msg <- merge.eval.string(env)

        # Nothing is printed when there are no evaluation results
        if (nchar(msg) > 0) {
          cat(msg, "\n")
        }

      }

    }

  }

  # Store attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.print.evaluation"

  # Return callback
  callback

}

# Build a callback that appends the evaluation results of each iteration to
# env$model$record_evals.
#
# The record is laid out as record_evals[[data_name]][[metric_name]] with two
# lists, eval and eval_err; record_evals$start_iter holds env$begin_iteration.
#
# Returns the callback, tagged with the attributes "call" and "name"
# ("cb.record.evaluation").
cb.record.evaluation <- function() {

  callback <- function(env) {

    # Nothing to record without evaluation results
    if (length(env$eval_list) <= 0) {
      return()
    }

    # Errors are only recorded when an error list was provided
    has_err <- length(env$eval_err_list) > 0

    # First call: create the empty record structure
    if (length(env$model$record_evals) == 0) {

      env$model$record_evals$start_iter <- env$begin_iteration

      for (res in env$eval_list) {

        dn <- res$data_name
        mn <- res$name

        if (is.null(env$model$record_evals[[dn]])) {
          env$model$record_evals[[dn]] <- list()
        }

        # Empty placeholders for values and errors of this metric
        env$model$record_evals[[dn]][[mn]] <- list(eval = list(), eval_err = list())

      }

    }

    # Append the results of the current iteration
    for (j in seq_along(env$eval_list)) {

      res <- env$eval_list[[j]]
      err <- if (has_err) env$eval_err_list[[j]] else NULL

      dn <- res$data_name
      mn <- res$name

      slot <- env$model$record_evals[[dn]][[mn]]
      slot$eval <- c(slot$eval, res$value)
      slot$eval_err <- c(slot$eval_err, err)
      env$model$record_evals[[dn]][[mn]] <- slot

    }

  }

  # Store attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.record.evaluation"

  callback

}

273
# Build a callback that stops training when no evaluation metric has improved
# for a given number of rounds.
#
# stopping_rounds: number of iterations without improvement after which
#   env$met_early_stop is set to TRUE
# verbose: messages are printed only when this is exactly TRUE
#
# The callback keeps, per metric, the best score seen so far and the iteration
# it was seen at. On stopping (or on reaching env$end_iteration) it writes
# env$best_iter and env$met_early_stop, and best_iter / best_score on
# env$model when a model is present.
#
# Returns the callback, tagged with the attributes "call" and "name"
# ("cb.early.stop"). Raises an error on first use when env$eval_list is empty.
cb.early.stop <- function(stopping_rounds, verbose = TRUE) {

  # State shared across calls of the callback; filled in by init()
  factor_to_bigger_better <- NULL
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  # One-time setup, run on the first call of the callback
  init <- function(env) {

    # Store evaluation length
    eval_len <<- length(env$eval_list)

    # Early stopping cannot work without metrics
    if (eval_len == 0) {
      stop("For early stopping, valids must have at least one element")
    }

    if (isTRUE(verbose)) {
      cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
    }

    # Scores are multiplied by +1 / -1 so that bigger is always better
    factor_to_bigger_better <<- rep.int(1.0, eval_len)
    best_iter <<- rep.int(-1, eval_len)
    best_score <<- rep.int(-Inf, eval_len)
    best_msg <<- rep(list(""), eval_len)

    # Flip the sign for metrics that are to be minimized
    for (i in seq_len(eval_len)) {
      if (!env$eval_list[[i]]$higher_better) {
        factor_to_bigger_better[i] <<- -1.0
      }
    }

  }

  # finalize is accepted for interface compatibility; it is not used here
  callback <- function(env, finalize = FALSE) {

    # Set up state lazily on the first call
    if (is.null(eval_len)) {
      init(env)
    }

    # Store iteration
    cur_iter <- env$iteration

    # Loop through evaluation metrics
    for (i in seq_len(eval_len)) {

      # Sign-adjusted score: bigger is always better
      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]

      if (score > best_score[i]) {

        # New best for this metric
        best_score[i] <<- score
        best_iter[i] <<- cur_iter

        # Remember the message of the best iteration for later printing.
        # isTRUE() keeps this consistent with the other verbose checks and
        # avoids an error for verbose = NULL or NA.
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(merge.eval.string(env))
        }

      } else {

        # No improvement: stop once the patience is used up
        if (cur_iter - best_iter[i] >= stopping_rounds) {

          # NOTE(review): best_score holds the sign-adjusted score, so it is
          # negated for metrics where lower is better - confirm callers
          # expect that.
          if (!is.null(env$model)) {
            env$model$best_score <- best_score[i]
            env$model$best_iter <- best_iter[i]
          }

          if (isTRUE(verbose)) {
            cat("Early stopping, best iteration is:", "\n")
            cat(best_msg[[i]], "\n")
          }

          # Store best iteration and stop
          env$best_iter <- best_iter[i]
          env$met_early_stop <- TRUE
        }

      }

      # Last iteration reached without early stopping: report the best one
      if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {

        if (!is.null(env$model)) {
          env$model$best_score <- best_score[i]
          env$model$best_iter <- best_iter[i]
        }

        if (isTRUE(verbose)) {
          cat("Did not meet early stopping, best iteration is:", "\n")
          cat(best_msg[[i]], "\n")
        }

        # Store best iteration and stop
        env$best_iter <- best_iter[i]
        env$met_early_stop <- TRUE
      }

    }
  }

  # Set attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.early.stop"

  # Return callback
  callback

}

# Collect the "name" attribute of every callback in cb_list into one vector.
# Callbacks without a "name" attribute contribute nothing to the result.
callback.names <- function(cb_list) {
  name_per_callback <- lapply(cb_list, function(cb) attr(cb, "name"))
  unlist(name_per_callback)
}
Guolin Ke's avatar
Guolin Ke committed
407
408

# Append callback cb to cb_list and name every element after its "name"
# attribute. When an early-stopping callback is present, it is moved to the
# end of the list.
add.cb <- function(cb_list, cb) {

  # Append the new callback and (re)name all elements
  combined <- c(cb_list, cb)
  names(combined) <- callback.names(combined)

  early_stop_key <- "cb.early.stop"

  # Without early stopping the order is kept as is
  if (!(early_stop_key %in% names(combined))) {
    return(combined)
  }

  # Append a copy of the early-stopping callback at the end, then drop the
  # first element carrying that name
  combined <- c(combined, combined[early_stop_key])
  combined[early_stop_key] <- NULL

  combined

}

# Split cb_list into callbacks run before an iteration (attribute
# "is_pre_iteration" present and true) and callbacks run after it (all
# others). Returns list(pre_iter = ..., post_iter = ...).
categorize.callbacks <- function(cb_list) {

  # TRUE when the callback is explicitly flagged as pre-iteration
  runs_before_iteration <- function(cb) {
    flag <- attr(cb, "is_pre_iteration")
    !is.null(flag) && flag
  }

  list(
    pre_iter = Filter(runs_before_iteration, cb_list)
    , post_iter = Filter(Negate(runs_before_iteration), cb_list)
  )

}