callback.R 10.6 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
CB_ENV <- R6::R6Class(
  # Mutable state object handed to every training callback. It is an R6
  # (reference) object, so fields written by one callback are visible to the
  # training loop and to the callbacks that run after it.
  "lgb.cb_env",
  cloneable = FALSE,
  public = list(

    # Model the callbacks act on. The callbacks defined alongside this class
    # call model$reset_parameter() and read/write model$record_evals,
    # model$best_score and model$best_iter.
    model = NULL,

    # Current iteration, and the first / last iteration of the run
    iteration = NULL,
    begin_iteration = NULL,
    end_iteration = NULL,

    # Evaluation results of the current iteration; callbacks read the
    # elements data_name, name, value and higher_better of each entry
    eval_list = list(),

    # Optional error terms, indexed in parallel with eval_list
    eval_err_list = list(),

    # Best iteration found so far (-1 until cb.early.stop sets it)
    best_iter = -1,

    # Best score; starts as NA
    best_score = NA,

    # Set to TRUE by cb.early.stop when training should stop
    met_early_stop = FALSE
  )
)

# Create a pre-iteration callback that resets model parameters every round.
#
# new_params: named list. Each element is either
#   * a numeric or character vector with one value per boosting round, or
#   * a function of exactly two arguments, called as p(i, nrounds) where i is
#     the zero-based offset of the current round (iteration - begin_iteration)
#     and nrounds is the total number of rounds.
#   Dots in the names are replaced by underscores.
#
# Returns a function(env) carrying the attributes "call",
# "is_pre_iteration" (TRUE) and "name" ("cb.reset.parameters").
#
# Errors if new_params is not a list and, on the first invocation of the
# callback, if env has no model, if a protected parameter is listed, or if an
# element is neither a two-argument function nor a vector of length nrounds.
cb.reset.parameters <- function(new_params) {

  # Check for parameter list
  if (!is.list(new_params)) {
    stop(sQuote("new_params"), " must be a list")
  }

  # Normalise parameter names (dots become underscores). The list itself is
  # renamed as well, so the values handed to the model carry the same names
  # that were validated. An unnamed list is left untouched.
  pnames <- gsub("\\.", "_", names(new_params))
  if (!is.null(names(new_params))) {
    names(new_params) <- pnames
  }

  # Number of boosting rounds; NULL until the callback has run once
  nrounds <- NULL

  # One-off validation, executed on the first call of the callback
  init <- function(env) {

    # Store boosting rounds
    nrounds <<- env$end_iteration - env$begin_iteration + 1

    # Check for model environment
    if (is.null(env$model)) {
      stop("Env should have a ", sQuote("model"))
    }

    # Some parameters must not be changed while boosting
    not_allowed <- c("num_class", "metric", "boosting_type")
    if (any(pnames %in% not_allowed)) {
      stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting")
    }

    # Validate each entry. Values are fetched by position so that the lookup
    # does not depend on how the names were spelled by the caller.
    for (k in seq_along(pnames)) {

      n <- pnames[k]
      p <- new_params[[k]]

      if (is.function(p)) {

        # Schedules given as functions must take exactly two arguments
        if (length(formals(p)) != 2) {
          stop("Parameter ", sQuote(n), " is a function but not of two arguments")
        }

      } else if (is.numeric(p) || is.character(p)) {

        # Schedules given as vectors need one value per round
        if (length(p) != nrounds) {
          stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
        }

      } else {

        stop("Parameter ", sQuote(n), " is not a function or a vector")

      }

    }

  }

  callback <- function(env) {

    # Validate lazily, the first time the callback is invoked
    if (is.null(nrounds)) {
      init(env)
    }

    # Zero-based offset of the current round within this training run
    i <- env$iteration - env$begin_iteration

    # Resolve the value of every parameter for this round
    pars <- lapply(new_params, function(p) {
      if (is.function(p)) {
        return(p(i, nrounds))
      }
      # R vectors are one-based: offset 0 must select the first element
      p[i + 1]
    })

    # To-do check pars
    if (!is.null(env$model)) {
      env$model$reset_parameter(pars)
    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "is_pre_iteration") <- TRUE
  attr(callback, "name") <- "cb.reset.parameters"
  callback
}

# Format the evaluation metric string
#
# eval_res: evaluation result with elements data_name, name and value.
# eval_err: optional error term appended after a "+"; omitted when NULL.
#
# Returns "<data_name>'s <name>:<value>" or, with an error term,
# "<data_name>'s <name>:<value>+<eval_err>" (numbers formatted with %g).
# Errors when eval_res is NULL or empty.
format.eval.string <- function(eval_res, eval_err = NULL) {

  # Check for empty evaluation string
  if (is.null(eval_res) || length(eval_res) == 0) {
    stop("no evaluation results")
  }

  # Append the error term only when one was supplied
  if (!is.null(eval_err)) {
    sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)
  } else {
    sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value)
  }

}

125
# Build the one-line evaluation message for the current iteration.
#
# env: callback environment; reads env$iteration, env$eval_list and
#      env$eval_err_list.
#
# Returns "" when env$eval_list is empty. Otherwise returns "[<iteration>]:"
# followed by one format.eval.string() entry per evaluation result, joined by
# tab characters. Error terms are included only when env$eval_err_list is
# non-empty, in which case it is indexed in parallel with env$eval_list.
merge.eval.string <- function(env) {

  # Check length of evaluation list
  n_evals <- length(env$eval_list)
  if (n_evals <= 0) {
    return("")
  }

  # Iteration prefix
  header <- sprintf("[%d]:", env$iteration)

  # Whether error terms accompany the evaluation results
  is_eval_err <- length(env$eval_err_list) > 0

  # Format every evaluation result in one pass
  pieces <- lapply(seq_len(n_evals), function(j) {
    eval_err <- NULL
    if (is_eval_err) {
      eval_err <- env$eval_err_list[[j]]
    }
    format.eval.string(env$eval_list[[j]], eval_err)
  })

  # Return tab-separated message
  paste0(c(header, unlist(pieces)), collapse = "\t")

}

157
# Create a callback that prints the evaluation message of an iteration.
#
# period: print every `period` iterations; values <= 0 disable printing.
#         The first and last iteration of the run are always printed.
#
# Returns a function(env) carrying the attributes "call" and "name"
# ("cb.print.evaluation").
cb.print.evaluation <- function(period = 1) {

  # Create callback
  callback <- function(env) {

    # Printing is disabled unless period is at least 1
    if (period > 0) {

      # Store iteration
      i <- env$iteration

      # Print on every period-th iteration and on the first / last one
      if ((i - 1) %% period == 0 || is.element(i, c(env$begin_iteration, env$end_iteration))) {

        # Build the message once and reuse it for the check and the output
        msg <- merge.eval.string(env)

        # Nothing to print when there are no evaluation results
        if (nchar(msg) > 0) {
          cat(msg, "\n")
        }

      }

    }

  }

  # Store attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.print.evaluation"

  # Return callback
  callback

}

# Create a callback that appends each iteration's evaluation results to
# env$model$record_evals.
#
# Layout written by the callback:
#   record_evals$start_iter                        env$begin_iteration
#   record_evals[[data_name]][[name]]$eval         list of metric values
#   record_evals[[data_name]][[name]]$eval_err     list of error terms
#
# Returns a function(env) carrying the attributes "call" and "name"
# ("cb.record.evaluation"). The callback returns NULL without touching the
# model when env$eval_list is empty.
cb.record.evaluation <- function() {

  # Create callback
  callback <- function(env) {

    # Return empty if empty evaluation list
    if (length(env$eval_list) <= 0) {
      return()
    }

    # Whether error terms accompany the evaluation results
    is_eval_err <- length(env$eval_err_list) > 0

    # First recorded iteration: lay out the storage
    if (length(env$model$record_evals) == 0) {

      # The starting iteration is the same for every metric, so set it once
      env$model$record_evals$start_iter <- env$begin_iteration

      # Loop through each evaluation list element
      for (j in seq_along(env$eval_list)) {

        # Store names
        data_name <- env$eval_list[[j]]$data_name
        name <- env$eval_list[[j]]$name

        # Check if evaluation record exists
        if (is.null(env$model$record_evals[[data_name]])) {
          env$model$record_evals[[data_name]] <- list()
        }

        # Create empty containers for this dataset / metric pair
        env$model$record_evals[[data_name]][[name]] <- list()
        env$model$record_evals[[data_name]][[name]]$eval <- list()
        env$model$record_evals[[data_name]][[name]]$eval_err <- list()

      }

    }

    # Append this iteration's results
    for (j in seq_along(env$eval_list)) {

      # Get evaluation data
      eval_res <- env$eval_list[[j]]
      eval_err <- NULL
      if (is_eval_err) {
        eval_err <- env$eval_err_list[[j]]
      }

      # Store names
      data_name <- eval_res$data_name
      name <- eval_res$name

      # Store evaluation data (a NULL eval_err leaves eval_err unchanged)
      env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value)
      env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err)

    }

  }

  # Store attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.record.evaluation"

  # Return callback
  callback

}

263
# Create a callback that stops training once a metric has not improved for
# `stopping_rounds` consecutive iterations.
#
# stopping_rounds: number of iterations without improvement after which
#                  env$met_early_stop is set to TRUE.
# verbose:         print progress messages only when this is exactly TRUE.
#
# Returns a function(env, finalize = FALSE) carrying the attributes "call"
# and "name" ("cb.early.stop"). The `finalize` argument is accepted but not
# used. On a stop, and on the last iteration when no stop was met, the
# callback writes env$best_iter, env$met_early_stop and, when env$model is
# not NULL, env$model$best_iter and env$model$best_score.
#
# Errors on the first invocation if env$eval_list is empty.
cb.early.stop <- function(stopping_rounds, verbose = TRUE) {

  # State shared by the closures below; filled in by init()
  factor_to_bigger_better <- NULL
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  # Initialization function, run on the first call of the callback
  init <- function(env) {

    # Store evaluation length
    eval_len <<- length(env$eval_list)

    # Early stopping cannot work without metrics
    if (eval_len == 0) {
      stop("For early stopping, valids must have at least one element")
    }

    # Check if verbose or not
    if (isTRUE(verbose)) {
      cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
    }

    # Scores are compared as "bigger is better"; metrics to be minimised
    # get a factor of -1 below
    factor_to_bigger_better <<- rep.int(1.0, eval_len)
    best_iter <<- rep.int(-1, eval_len)
    best_score <<- rep.int(-Inf, eval_len)
    best_msg <<- rep(list(""), eval_len)

    # Flip the sign for metrics where lower is better
    for (i in seq_len(eval_len)) {
      if (!env$eval_list[[i]]$higher_better) {
        factor_to_bigger_better[i] <<- -1.0
      }
    }

  }

  # Publish the best result of metric i and flag that training should stop
  record_best <- function(env, i, headline) {

    # Check if model is not null
    if (!is.null(env$model)) {
      # best_score is kept sign-normalised internally; undo the factor so
      # the model receives the score on the metric's own scale
      env$model$best_score <- best_score[i] * factor_to_bigger_better[i]
      env$model$best_iter <- best_iter[i]
    }

    # Print message if verbose
    if (isTRUE(verbose)) {
      cat(headline, "\n")
      cat(best_msg[[i]], "\n")
    }

    # Store best iteration and stop
    env$best_iter <- best_iter[i]
    env$met_early_stop <- TRUE

  }

  # Create callback
  callback <- function(env, finalize = FALSE) {

    # Initialise lazily on the first call
    if (is.null(eval_len)) {
      init(env)
    }

    # Store iteration
    cur_iter <- env$iteration

    # Loop through evaluation
    for (i in seq_len(eval_len)) {

      # Sign-normalised score: larger is always better
      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]

      if (score > best_score[i]) {

        # Store new scores
        best_score[i] <<- score
        best_iter[i] <<- cur_iter

        # Remember the message to print later; isTRUE() keeps this branch
        # consistent with the other verbose checks for NULL / NA values
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(merge.eval.string(env))
        }

      } else if (cur_iter - best_iter[i] >= stopping_rounds) {

        # No improvement for stopping_rounds iterations
        record_best(env, i, "Early stopping, best iteration is:")

      }

      # Last iteration reached without an early stop: report the best result
      if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {
        record_best(env, i, "Did not meet early stopping, best iteration is:")
      }

    }

  }

  # Set attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.early.stop"

  # Return callback
  callback

}

# Collect the "name" attribute of every callback in cb_list into one vector.
# Callbacks without that attribute contribute nothing to the result.
callback.names <- function(cb_list) {
  name_of <- function(one_cb) {
    attr(one_cb, "name")
  }
  unlist(lapply(cb_list, name_of))
}
Guolin Ke's avatar
Guolin Ke committed
395
396

# Append a callback to a callback list and keep early stopping last.
#
# cb_list: list of callbacks, each carrying a "name" attribute.
# cb:      callback to append.
#
# Returns the extended list, named after the callbacks' "name" attributes.
# If an element named "cb.early.stop" is present, it is moved to the end of
# the list.
add.cb <- function(cb_list, cb) {

  # Combine two elements
  cb_list <- c(cb_list, cb)

  # Set names of elements
  names(cb_list) <- callback.names(cb_list)

  # Check for existence
  if ("cb.early.stop" %in% names(cb_list)) {

    # Append a copy of the early-stopping callback at the end ...
    cb_list <- c(cb_list, cb_list["cb.early.stop"])

    # ... and drop the original; indexing by name hits the first match only
    cb_list["cb.early.stop"] <- NULL

  }

  # Return element
  cb_list

}

# Split a callback list into pre-iteration and post-iteration callbacks.
#
# cb_list: list of callbacks; a callback counts as pre-iteration when its
#          "is_pre_iteration" attribute is a single TRUE value.
#
# Returns list(pre_iter = ..., post_iter = ...). Every callback lands in
# exactly one of the two lists, in its original order; a missing, FALSE, NA
# or non-scalar attribute puts it in post_iter.
categorize.callbacks <- function(cb_list) {

  # One predicate drives both filters, so the two lists always partition
  # cb_list. isTRUE() absorbs NULL, NA and non-scalar attribute values.
  is_pre_iteration <- function(x) {
    isTRUE(as.logical(attr(x, "is_pre_iteration")))
  }

  list(
    pre_iter = Filter(is_pre_iteration, cb_list),
    post_iter = Filter(Negate(is_pre_iteration), cb_list)
  )

}