callback.R 10.1 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
# Mutable state object shared with the training callbacks (class name
# "lgb.cb_env"). Cloning is disabled.
#
# Public fields and their initial values:
#   model           - NULL; callbacks in this file call methods on it and
#                     store results in it once it has been assigned
#   iteration       - NULL; the current iteration
#   begin_iteration - NULL; first iteration of the run
#   end_iteration   - NULL; last iteration of the run
#   eval_list       - empty list; evaluation results, one entry per metric
#   eval_err_list   - empty list; evaluation errors aligned with eval_list
#   best_iter       - -1; overwritten by cb.early.stop when it triggers
#   best_score      - -1
#   met_early_stop  - FALSE; set to TRUE by cb.early.stop when it triggers
CB_ENV <- R6Class(
  "lgb.cb_env",
  cloneable = FALSE,
  public = list(
    model = NULL,
    iteration = NULL,
    begin_iteration = NULL,
    end_iteration = NULL,
    eval_list = list(),
    eval_err_list = list(),
    best_iter = -1,
    best_score = -1,
    met_early_stop = FALSE
  )
)

# Build a pre-iteration callback that resets model parameters every round.
#
# new_params: a named list. Each element must be either
#   * a function of exactly two arguments; it is called as p(i, nrounds),
#     where i is the zero-based offset of the current iteration from
#     env$begin_iteration and nrounds is the number of boosting rounds, or
#   * a numeric or character vector holding one value per boosting round.
#
# Parameter names are compared against the forbidden list with dots
# replaced by underscores. The list handed to the model keeps the names
# exactly as given in new_params.
#
# Returns the callback, carrying the attributes "call", "is_pre_iteration"
# (TRUE) and "name" ("cb.reset.parameters").
# Raises an error if new_params is not a list or, on the first invocation,
# if any of the checks in init() fails.
cb.reset.parameters <- function(new_params) {

  # Check for parameter list
  if (!is.list(new_params)) {
    stop(sQuote("new_params"), " must be a list")
  }

  # Normalised parameter names (dots become underscores)
  pnames <- gsub("\\.", "_", names(new_params))

  # Number of boosting rounds; stays NULL until init() has succeeded
  nrounds <- NULL

  # One-off validation, run lazily on the first callback invocation
  init <- function(env) {

    # Check for model environment
    if (is.null(env$model)) {
      stop("Env should have a ", sQuote("model"))
    }

    # Number of boosting rounds in this run
    n_rounds <- env$end_iteration - env$begin_iteration + 1

    # Some parameters are not allowed to be changed during boosting
    not_allowed <- c("num_class", "metric", "boosting_type")
    if (any(pnames %in% not_allowed)) {
      stop(
        "Parameters ",
        paste0(pnames[pnames %in% not_allowed], collapse = ", "),
        " cannot be changed during boosting"
      )
    }

    # Validate each parameter. Values are fetched by position because the
    # normalised name may differ from the name stored in new_params.
    for (k in seq_along(pnames)) {

      n <- pnames[[k]]
      p <- new_params[[k]]

      if (is.function(p)) {

        # A schedule function must take exactly two arguments
        if (length(formals(p)) != 2) {
          stop("Parameter ", sQuote(n), " is a function but not of two arguments")
        }

      } else if (is.numeric(p) || is.character(p)) {

        # A schedule vector must provide one value per round
        if (length(p) != n_rounds) {
          stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
        }

      } else {

        stop("Parameter ", sQuote(n), " is not a function or a vector")

      }

    }

    # Publish only after every check passed, so that a failed validation is
    # repeated on the next invocation instead of being skipped
    nrounds <<- n_rounds

  }

  callback <- function(env) {

    # Validate on first use
    if (is.null(nrounds)) {
      init(env)
    }

    # Zero-based offset of the current iteration within this run
    i <- env$iteration - env$begin_iteration

    # Resolve the value of every parameter for this round
    pars <- lapply(new_params, function(p) {
      if (is.function(p)) {
        return(p(i, nrounds))
      }
      # R vectors are one-based, so shift the zero-based offset
      p[i + 1]
    })

    # To-do check pars
    if (!is.null(env$model)) {
      env$model$reset_parameter(pars)
    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "is_pre_iteration") <- TRUE
  attr(callback, "name") <- "cb.reset.parameters"
  callback

}

# Format a single evaluation result as text.
#
# eval_res: list with the elements data_name, name and value.
# eval_err: optional evaluation error; when supplied it is appended to the
#           value after a "+".
#
# Returns "<data_name>'s <name>:<value>" or
# "<data_name>'s <name>:<value>+<eval_err>".
# Raises an error when eval_res is NULL or has length zero.
format.eval.string <- function(eval_res, eval_err = NULL) {

  # Check for empty evaluation result
  if (is.null(eval_res) || length(eval_res) == 0) {
    stop("no evaluation results")
  }

  # Without an error term only the metric value is reported
  if (is.null(eval_err)) {
    return(
      sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value)
    )
  }

  sprintf(
    "%s\'s %s:%g+%g",
    eval_res$data_name,
    eval_res$name,
    eval_res$value,
    eval_err
  )

}

124
# Build the one-line evaluation message for the current iteration.
#
# env: object exposing iteration, eval_list and eval_err_list.
#
# Returns "" when env$eval_list is empty. Otherwise returns "[<iteration>]:"
# followed by one format.eval.string() entry per evaluation result, all
# joined by tab characters. Evaluation errors are included only when
# env$eval_err_list is non-empty.
merge.eval.string <- function(env) {

  # Nothing to report without evaluation results
  if (length(env$eval_list) <= 0) {
    return("")
  }

  # Evaluation errors are optional
  has_eval_err <- length(env$eval_err_list) > 0

  # Format every evaluation result, pairing it with its error when present
  entries <- lapply(seq_along(env$eval_list), function(j) {
    eval_err <- NULL
    if (has_eval_err) {
      eval_err <- env$eval_err_list[[j]]
    }
    format.eval.string(env$eval_list[[j]], eval_err)
  })

  # Return tab-separated message, prefixed by the iteration tag
  paste0(c(sprintf("[%d]:", env$iteration), unlist(entries)), collapse = "\t")

}

161
162
163
# Build a callback that prints the evaluation message to the console.
#
# period: printing interval. Nothing is printed unless period > 0. The
#         message is printed when (iteration - 1) is a multiple of period,
#         and also on env$begin_iteration and env$end_iteration.
#
# Returns the callback, carrying the attributes "call" and "name"
# ("cb.print.evaluation").
cb.print.evaluation <- function(period = 1) {

  # Create callback
  callback <- function(env) {

    # Printing is disabled unless period is positive
    if (!(period > 0)) {
      return(invisible(NULL))
    }

    # Store iteration
    i <- env$iteration

    # Scalar test, hence the short-circuit operator
    is_due <- (i - 1) %% period == 0 ||
      i == env$begin_iteration ||
      i == env$end_iteration
    if (!is_due) {
      return(invisible(NULL))
    }

    # Build the message once and print it only when it is non-empty
    msg <- merge.eval.string(env)
    if (nchar(msg) > 0) {
      cat(msg, "\n")
    }

  }

  # Store attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.print.evaluation"

  # Return callback
  callback

}

# Build a callback that appends each iteration's evaluation results to
# env$model$record_evals.
#
# Layout written by the callback:
#   record_evals$start_iter                     - env$begin_iteration
#   record_evals[[data_name]][[name]]$eval      - list of metric values
#   record_evals[[data_name]][[name]]$eval_err  - list of evaluation errors
#
# The per-metric slots and start_iter are created only on an invocation
# that finds record_evals empty.
#
# Returns the callback, carrying the attributes "call" and "name"
# ("cb.record.evaluation").
cb.record.evaluation <- function() {

  # Create callback
  callback <- function(env) {

    # Nothing to record without evaluation results
    if (length(env$eval_list) <= 0) {
      return()
    }

    # Evaluation errors are optional
    has_eval_err <- length(env$eval_err_list) > 0

    # First recording: remember the starting iteration and create one empty
    # slot per (data_name, name) pair
    if (length(env$model$record_evals) == 0) {

      env$model$record_evals$start_iter <- env$begin_iteration

      for (eval_res in env$eval_list) {

        data_name <- eval_res$data_name
        name <- eval_res$name

        # Create the per-dataset container on first sight
        if (is.null(env$model$record_evals[[data_name]])) {
          env$model$record_evals[[data_name]] <- list()
        }

        env$model$record_evals[[data_name]][[name]] <- list(
          eval = list(),
          eval_err = list()
        )

      }

    }

    # Append the values of this iteration
    for (j in seq_along(env$eval_list)) {

      eval_res <- env$eval_list[[j]]
      eval_err <- NULL
      if (has_eval_err) {
        eval_err <- env$eval_err_list[[j]]
      }

      data_name <- eval_res$data_name
      name <- eval_res$name

      # Work on a local copy of the slot, then write it back in one step
      slot <- env$model$record_evals[[data_name]][[name]]
      slot$eval <- c(slot$eval, eval_res$value)
      slot$eval_err <- c(slot$eval_err, eval_err)
      env$model$record_evals[[data_name]][[name]] <- slot

    }

  }

  # Store attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.record.evaluation"

  # Return callback
  callback

}

272
# Build a callback that requests early stopping when a metric stops
# improving.
#
# stopping_rounds: stop once a metric has gone this many iterations without
#                  beating its best value.
# verbose:         when TRUE, print a start-up notice and, on stopping, the
#                  evaluation message recorded at the best iteration.
#
# Every metric in env$eval_list is tracked separately. Scores are multiplied
# by -1 for metrics whose higher_better flag is FALSE, so that larger is
# always better internally.
#
# When the stopping condition is met for a metric the callback sets
# env$best_iter and env$met_early_stop, and, if env$model is not NULL,
# env$model$best_iter and env$model$best_score.
#
# Returns the callback, carrying the attributes "call" and "name"
# ("cb.early.stop"). The callback accepts a finalize argument that is not
# used in its body.
# Raises an error on invocation when env$eval_list is empty.
cb.early.stop <- function(stopping_rounds, verbose = TRUE) {

  # State shared between invocations of the callback
  factor_to_bigger_better <- NULL
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  # Initialization function, run on the first callback invocation
  init <- function(env) {

    # Number of tracked metrics
    n_evals <- length(env$eval_list)

    # Early stopping cannot work without metrics
    if (n_evals == 0) {
      stop("For early stopping, valids must have at least one element")
    }

    # Check if verbose or not
    if (isTRUE(verbose)) {
      cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
    }

    # Per-metric bookkeeping
    factor_to_bigger_better <<- rep(1.0, n_evals)
    best_iter <<- rep(-1, n_evals)
    best_score <<- rep(-Inf, n_evals)
    best_msg <<- rep(list(""), n_evals)

    # Flip the sign for metrics that are to be minimized
    for (i in seq_len(n_evals)) {
      if (!env$eval_list[[i]]$higher_better) {
        factor_to_bigger_better[i] <<- -1.0
      }
    }

    # Published last: a failed initialization is retried on the next call
    eval_len <<- n_evals

  }

  # Create callback
  callback <- function(env, finalize = FALSE) {

    # Initialize on first use
    if (is.null(eval_len)) {
      init(env)
    }

    # Store iteration
    cur_iter <- env$iteration

    # Loop through evaluation
    for (i in seq_len(eval_len)) {

      # Sign-adjusted score: larger is always better
      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]

      if (score > best_score[i]) {

        # New best for this metric
        best_score[i] <<- score
        best_iter[i] <<- cur_iter

        # Remember the message to print if stopping happens later
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(merge.eval.string(env))
        }

      } else if (cur_iter - best_iter[i] >= stopping_rounds) {

        # NOTE(review): best_score holds the sign-adjusted score, so the
        # value stored on the model is negated for metrics whose
        # higher_better flag is FALSE - confirm this is what callers expect
        if (!is.null(env$model)) {
          env$model$best_score <- best_score[i]
          env$model$best_iter <- best_iter[i]
        }

        # Print message if verbose
        if (isTRUE(verbose)) {
          cat("Early stopping, best iteration is:", "\n")
          cat(best_msg[[i]], "\n")
        }

        # Store best iteration and flag the stop
        env$best_iter <- best_iter[i]
        env$met_early_stop <- TRUE

      }

    }

  }

  # Set attributes
  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.early.stop"

  # Return callback
  callback

}

# Collect the "name" attribute of every callback in cb_list into a single
# vector (via unlist, so callbacks without that attribute contribute nothing)
callback.names <- function(cb_list) {
  name_of <- function(cb) {
    attr(cb, "name")
  }
  unlist(lapply(cb_list, name_of))
}
Guolin Ke's avatar
Guolin Ke committed
389
390

# Append a callback to a callback list and name every element after its
# "name" attribute.
#
# cb_list: existing list of callbacks.
# cb:      callback to append.
#
# If the resulting list contains an element named "cb.early.stop", the first
# such element is moved to the end of the list (presumably so that it runs
# after the other callbacks - confirm against the caller).
#
# Returns the updated, named list.
add.cb <- function(cb_list, cb) {

  # Append the new callback
  cb_list <- c(cb_list, cb)

  # Name every element after its callback
  names(cb_list) <- callback.names(cb_list)

  # Move the early stopping callback to the end
  if ("cb.early.stop" %in% names(cb_list)) {

    # Character subsetting picks the first match; append a copy of it
    cb_list <- c(cb_list, cb_list["cb.early.stop"])

    # Remove only the first one, leaving the appended copy in place
    cb_list["cb.early.stop"] <- NULL

  }

  # Return the updated list
  cb_list

}

# Split a callback list by the "is_pre_iteration" attribute.
#
# cb_list: list of callback functions.
#
# Returns a list with two elements:
#   pre_iter  - callbacks whose "is_pre_iteration" attribute is set and true
#   post_iter - callbacks where the attribute is missing or false
categorize.callbacks <- function(cb_list) {

  # TRUE only for callbacks explicitly flagged as pre-iteration
  is_pre_iteration <- function(cb) {
    pre <- attr(cb, "is_pre_iteration")
    !is.null(pre) && pre
  }

  list(
    pre_iter = Filter(is_pre_iteration, cb_list),
    post_iter = Filter(Negate(is_pre_iteration), cb_list)
  )

}