# Constants that control naming in lists.
# cb_record_evaluation() uses these as the keys under which it stores, for each
# dataset / metric pair, the recorded values (.EVAL_KEY) and the recorded
# error terms (.EVAL_ERR_KEY).
.EVAL_KEY <- function() "eval"

.EVAL_ERR_KEY <- function() "eval_err"

# Mutable state object handed to the training callbacks.
# In this file, callbacks read `model`, `iteration`, `begin_iteration`,
# `end_iteration`, `eval_list` and `eval_err_list`; cb_early_stop() writes
# `best_iter` and `met_early_stop`.
#' @importFrom R6 R6Class
CB_ENV <- R6::R6Class(
  classname = "lgb.cb_env",
  cloneable = FALSE,
  public = list(
    # Model being trained; cb_record_evaluation() appends to model$record_evals
    model = NULL,
    # Current, first and last training iteration
    iteration = NULL,
    begin_iteration = NULL,
    end_iteration = NULL,
    # Evaluation results of the current iteration (each element read as a list
    # with data_name / name / value / higher_better) and, optionally, error
    # terms indexed in parallel with eval_list
    eval_list = list(),
    eval_err_list = list(),
    # Best iteration found by early stopping (-1L until set)
    best_iter = -1L,
    # Not written by the callbacks in this file (cb_early_stop() writes
    # model$best_score instead)
    best_score = NA,
    # Set to TRUE by cb_early_stop() once training should stop
    met_early_stop = FALSE
  )
)

# Format a single evaluation result as "<data_name>'s <name>:<value>", or as
# "<data_name>'s <name>:<value>+<eval_err>" when an error term is supplied.
#
# eval_res: list with `data_name`, `name` and `value` elements; must be non-empty.
# eval_err: error term to append, or NULL to omit it.
# Raises an error ("no evaluation results") when eval_res is NULL or empty.
.format_eval_string <- function(eval_res, eval_err) {

  # Refuse to format a missing / empty result
  if (is.null(eval_res) || length(eval_res) == 0L) {
    stop("no evaluation results")
  }

  # Without an error term, only the metric value is shown
  if (is.null(eval_err)) {
    return(sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value))
  }

  sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)
}

# Merge all evaluation results of the current iteration into one line:
# "[<iteration>]:" followed by one .format_eval_string() entry per element of
# env$eval_list, all separated by two spaces.
# Returns "" when env$eval_list is empty.
.merge_eval_string <- function(env) {

  n_evals <- length(env$eval_list)
  if (n_evals <= 0L) {
    return("")
  }

  # Error terms are used only when env$eval_err_list is non-empty; it is then
  # indexed in parallel with env$eval_list
  has_errors <- length(env$eval_err_list) > 0L

  formatted <- lapply(seq_len(n_evals), function(k) {
    err <- if (has_errors) env$eval_err_list[[k]] else NULL
    .format_eval_string(eval_res = env$eval_list[[k]], eval_err = err)
  })

  # Header first, then each formatted entry as its own element
  pieces <- c(list(sprintf("[%d]:", env$iteration)), unlist(formatted))
  paste0(pieces, collapse = "  ")
}

# Build a callback that prints the merged evaluation string on iterations
# 1, 1 + period, 1 + 2 * period, ... and always on env$begin_iteration and
# env$end_iteration.
#
# period: print frequency; values <= 0 disable printing entirely.
# Returns a function(env) (tagged with "call" and "name" attributes) that
# always returns invisible(NULL).
cb_print_evaluation <- function(period) {

  callback <- function(env) {

    # A non-positive period disables printing
    if (period > 0L) {

      i <- env$iteration

      # Print every `period`-th iteration (counting from 1), plus the first
      # and last iterations
      if ((i - 1L) %% period == 0L || is.element(i, c(env$begin_iteration, env$end_iteration))) {

        # Build the message once and print that same string
        msg <- .merge_eval_string(env = env)

        # .merge_eval_string() returns "" when there are no evaluation results
        if (nchar(msg) > 0L) {
          print(msg)
        }

      }

    }

    return(invisible(NULL))
  }

  # Store attributes used to identify / order callbacks
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb_print_evaluation"

  return(callback)
}

# Build a callback that appends each iteration's evaluation results to
# env$model$record_evals. Layout written by the callback:
#   record_evals$start_iter                                    <- env$begin_iteration
#   record_evals[[data_name]][[metric]][[.EVAL_KEY()]]         one value per call
#   record_evals[[data_name]][[metric]][[.EVAL_ERR_KEY()]]     one error term per call
#                                                              (stays empty when
#                                                              env$eval_err_list is empty)
# Returns a function(env) tagged with "call" and "name" attributes.
cb_record_evaluation <- function() {

  callback <- function(env) {

    n_evals <- length(env$eval_list)

    # Nothing to record this iteration
    if (n_evals <= 0L) {
      return()
    }

    has_errors <- length(env$eval_err_list) > 0L

    # First call: create empty value / error slots for every dataset-metric pair
    if (length(env$model$record_evals) == 0L) {
      for (k in seq_len(n_evals)) {
        ds_name <- env$eval_list[[k]]$data_name
        metric_name <- env$eval_list[[k]]$name
        env$model$record_evals$start_iter <- env$begin_iteration
        if (is.null(env$model$record_evals[[ds_name]])) {
          env$model$record_evals[[ds_name]] <- list()
        }
        env$model$record_evals[[ds_name]][[metric_name]] <- list()
        env$model$record_evals[[ds_name]][[metric_name]][[.EVAL_KEY()]] <- list()
        env$model$record_evals[[ds_name]][[metric_name]][[.EVAL_ERR_KEY()]] <- list()
      }
    }

    # Append this iteration's values (and error terms, when present)
    for (k in seq_len(n_evals)) {
      eval_res <- env$eval_list[[k]]
      eval_err <- if (has_errors) env$eval_err_list[[k]] else NULL
      ds_name <- eval_res$data_name
      metric_name <- eval_res$name

      env$model$record_evals[[ds_name]][[metric_name]][[.EVAL_KEY()]] <- c(
        env$model$record_evals[[ds_name]][[metric_name]][[.EVAL_KEY()]]
        , eval_res$value
      )
      env$model$record_evals[[ds_name]][[metric_name]][[.EVAL_ERR_KEY()]] <- c(
        env$model$record_evals[[ds_name]][[metric_name]][[.EVAL_ERR_KEY()]]
        , eval_err
      )
    }

    return(invisible(NULL))
  }

  # Store attributes used to identify / order callbacks
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb_record_evaluation"

  return(callback)
}

# Build an early-stopping callback.
#
# stopping_rounds:   number of iterations without improvement (for a metric)
#                    after which training is flagged to stop.
# first_metric_only: if TRUE, only the first element of env$eval_list is checked;
#                    otherwise every metric can trigger stopping.
# verbose:           if TRUE, status messages and the best iteration are printed.
#
# Per-metric state (best score / best iteration / best message) lives in this
# function's environment and is initialized lazily on the first call.
# When stopping is decided (either early, or on env$end_iteration), the
# callback sets env$best_iter and env$met_early_stop, and, when env$model is
# non-NULL, env$model$best_iter and env$model$best_score. The first metric that
# decides stopping wins; later metrics are not examined in that iteration.
cb_early_stop <- function(stopping_rounds, first_metric_only, verbose) {

  # Shared with init() / callback() through `<<-`; NULL until initialized
  factor_to_bigger_better <- NULL
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  # Initialize per-metric state from the first environment seen
  init <- function(env) {

    # Early stopping cannot work without metrics
    if (length(env$eval_list) == 0L) {
      stop("For early stopping, valids must have at least one element")
    }

    eval_len <<- length(env$eval_list)

    if (isTRUE(verbose)) {
      msg <- paste0(
        "Will train until there is no improvement in "
        , stopping_rounds
        , " rounds."
      )
      print(msg)
    }

    # Internally treat everything as a maximization task: scores of metrics
    # not flagged higher_better = TRUE are negated
    higher_better <- vapply(
      env$eval_list
      , function(eval_res) isTRUE(eval_res$higher_better)
      , logical(1L)
      , USE.NAMES = FALSE
    )
    factor_to_bigger_better <<- ifelse(higher_better, 1.0, -1.0)
    best_iter <<- rep.int(-1L, eval_len)
    best_score <<- rep.int(-Inf, eval_len)
    # Merged evaluation string at each metric's best iteration (filled only if verbose)
    best_msg <<- as.list(rep.int("", eval_len))

    return(invisible(NULL))
  }

  callback <- function(env) {

    # Lazily initialize on the first call
    if (is.null(eval_len)) {
      init(env = env)
    }

    cur_iter <- env$iteration

    # By default, any metric can trigger early stopping. This can be disabled
    # with 'first_metric_only = TRUE'
    if (isTRUE(first_metric_only)) {
      evals_to_check <- 1L
    } else {
      evals_to_check <- seq_len(eval_len)
    }

    for (i in evals_to_check) {

      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]

      if (score > best_score[i]) {

        # New best for this metric
        best_score[i] <<- score
        best_iter[i] <<- cur_iter

        # isTRUE() like the other verbose checks; a bare `if (verbose)`
        # errors when verbose is NA or NULL
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(.merge_eval_string(env = env))
        }

      } else if (cur_iter - best_iter[i] >= stopping_rounds) {

        # No improvement for `stopping_rounds` iterations: stop early
        if (!is.null(env$model)) {
          env$model$best_score <- best_score[i]
          env$model$best_iter <- best_iter[i]
        }

        if (isTRUE(verbose)) {
          print(paste0("Early stopping, best iteration is: ", best_msg[[i]]))
        }

        env$best_iter <- best_iter[i]
        env$met_early_stop <- TRUE
      }

      # Last iteration reached without stopping early
      if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {

        if (!is.null(env$model)) {
          env$model$best_score <- best_score[i]
          env$model$best_iter <- best_iter[i]
        }

        if (isTRUE(verbose)) {
          print(paste0("Did not meet early stopping, best iteration is: ", best_msg[[i]]))
        }

        env$best_iter <- best_iter[i]
        env$met_early_stop <- TRUE
      }

      # Metric i has decided to stop: do not let later metrics overwrite
      # env$best_iter or print a second stopping message
      if (isTRUE(env$met_early_stop)) {
        break
      }
    }

    return(invisible(NULL))
  }

  # Store attributes used to identify / order callbacks
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb_early_stop"

  return(callback)
}

# Collect the "name" attribute of every callback in cb_list.
# Callbacks without a "name" attribute contribute nothing; an empty list
# yields NULL.
callback.names <- function(cb_list) {
  cb_names <- lapply(cb_list, function(cb) attr(cb, "name"))
  unlist(cb_names)
}
# Append `cb` to `cb_list`, name every element after its "name" attribute,
# and move the first callback named "cb_early_stop" (if any) to the end of
# the list so that it runs after the other callbacks.
add.cb <- function(cb_list, cb) {

  combined <- c(cb_list, cb)
  names(combined) <- callback.names(cb_list = combined)

  # Position of the first early-stopping callback, NA when there is none
  es_pos <- match("cb_early_stop", names(combined))
  if (!is.na(es_pos)) {
    combined <- c(combined[-es_pos], combined[es_pos])
  }

  combined
}

# Split callbacks by their "is_pre_iteration" attribute.
# Returns list(pre_iter = callbacks whose attribute is TRUE,
#              post_iter = callbacks whose attribute is missing or FALSE).
# Names of cb_list are preserved in both parts.
categorize.callbacks <- function(cb_list) {

  pre_flag <- function(cb) attr(cb, "is_pre_iteration")

  runs_before <- function(cb) {
    flag <- pre_flag(cb)
    !is.null(flag) && flag
  }

  runs_after <- function(cb) {
    flag <- pre_flag(cb)
    is.null(flag) || !flag
  }

  list(
    pre_iter = Filter(runs_before, cb_list),
    post_iter = Filter(runs_after, cb_list)
  )
}