# callback.R
#
# Shared state object handed to every training callback in this file.
# Defined as a non-cloneable R6 class so that callbacks mutate one shared
# instance (e.g. cb.early.stop writes best_iter / met_early_stop on it).
CB_ENV <- R6Class(
  "lgb.cb_env",
  cloneable = FALSE,
  public = list(
    # Model object; callbacks in this file call model$reset_parameter() and
    # read/write model$record_evals and model$best_iter
    model = NULL,
    # Current iteration, and the first / last iteration of the run
    iteration = NULL,
    begin_iteration = NULL,
    end_iteration = NULL,
    # Evaluation results for the current iteration; callbacks read the
    # fields data_name, name, value and higher_better from each element
    eval_list = list(),
    # Optional evaluation errors, indexed in parallel with eval_list
    eval_err_list = list(),
    # Written by cb.early.stop when the stopping criterion is met
    best_iter = -1,
    met_early_stop = FALSE
  )
)

# Create a pre-iteration callback that resets model parameters every round.
#
# Args:
#   new_params: a list of parameter schedules. Names may use "." or "_" as
#     separator; dots are normalised to underscores before the parameters are
#     checked and handed to the model. Each element is either
#       * a function of exactly two arguments, called as p(i, nrounds) where
#         i is the zero-based offset of the current iteration from
#         env$begin_iteration, or
#       * a numeric / character vector holding one value per boosting round
#         (element 1 is used for the first round).
#
# Returns:
#   A function(env) with attributes "call", "is_pre_iteration" (TRUE) and
#   "name" ("cb.reset.parameters").
#
# Errors:
#   Immediately if new_params is not a list. On the first invocation of the
#   callback: if env has no model, if a protected parameter is listed, or if
#   an element has the wrong type, arity or length.
cb.reset.parameters <- function(new_params) {

  if (!is.list(new_params)) {
    stop(sQuote("new_params"), " must be a list")
  }

  # Normalised parameter names ("a.b" -> "a_b"), parallel to new_params
  pnames <- gsub("\\.", "_", names(new_params))

  # Number of boosting rounds; stays NULL until validation has succeeded
  nrounds <- NULL

  # One-off validation, run lazily on the first callback invocation because
  # the number of rounds is only known from env
  init <- function(env) {

    rounds <- env$end_iteration - env$begin_iteration + 1

    if (is.null(env$model)) {
      stop("Env should have a ", sQuote("model"))
    }

    # Some parameters are not allowed to be changed during boosting
    not_allowed <- c("num_class", "metric", "boosting_type")
    blocked <- pnames[pnames %in% not_allowed]
    if (length(blocked) > 0) {
      stop("Parameters ", paste0(blocked, collapse = ", "), " cannot be changed during boosting")
    }

    # Iterate by position: looking elements up by the normalised name would
    # miss every parameter that was supplied with dots in its name
    for (k in seq_along(new_params)) {

      n <- pnames[k]
      p <- new_params[[k]]

      if (is.function(p)) {
        if (length(formals(p)) != 2) {
          stop("Parameter ", sQuote(n), " is a function but not of two arguments")
        }
      } else if (is.numeric(p) || is.character(p)) {
        if (length(p) != rounds) {
          stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
        }
      } else {
        stop("Parameter ", sQuote(n), " is not a function or a vector")
      }

    }

    # Only mark the callback as initialised once every check has passed
    nrounds <<- rounds

  }

  callback <- function(env) {

    if (is.null(nrounds)) {
      init(env)
    }

    # Zero-based offset of the current round
    i <- env$iteration - env$begin_iteration

    pars <- lapply(new_params, function(p) {
      if (is.function(p)) {
        return(p(i, nrounds))
      }
      # R vectors are 1-based: offset 0 must select the first element
      p[i + 1]
    })

    # Hand the model the same normalised names that were validated
    if (!is.null(names(new_params))) {
      names(pars) <- pnames
    }

    # To-do check pars
    if (!is.null(env$model)) {
      env$model$reset_parameter(pars)
    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "is_pre_iteration") <- TRUE
  attr(callback, "name") <- "cb.reset.parameters"
  callback

}

# Format a single evaluation result as a printable string.
#
# Args:
#   eval_res: a list with the fields data_name, name and value.
#   eval_err: optional numeric error; when supplied it is appended after the
#     value, separated by "+".
#
# Returns:
#   "<data_name>'s <name>:<value>" or "<data_name>'s <name>:<value>+<err>".
#
# Errors:
#   If eval_res is NULL or empty.
format.eval.string <- function(eval_res, eval_err = NULL) {

  # length(NULL) is 0, so this single check covers NULL and empty input
  if (length(eval_res) == 0) {
    stop("no evaluation results")
  }

  if (is.null(eval_err)) {
    return(sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value))
  }

  sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)

}

# Build the one-line evaluation summary for the current iteration.
#
# Args:
#   env: callback environment; reads env$iteration, env$eval_list and
#     env$eval_err_list.
#
# Returns:
#   "" when env$eval_list is empty; otherwise "[<iteration>]:" followed by
#   one format.eval.string() entry per evaluation result, tab-separated.
merge.eval.string <- function(env) {

  if (length(env$eval_list) <= 0) {
    return("")
  }

  header <- sprintf("[%d]:", env$iteration)

  # Errors are only attached when an error list was supplied at all
  has_err <- length(env$eval_err_list) > 0

  # Format every result in one pass instead of growing a list with c()
  entries <- lapply(seq_along(env$eval_list), function(j) {
    err <- if (has_err) env$eval_err_list[[j]] else NULL
    format.eval.string(env$eval_list[[j]], err)
  })

  paste0(c(header, unlist(entries)), collapse = "\t")

}

# Create a callback that prints the evaluation summary every `period` rounds.
#
# Args:
#   period: print every `period` iterations; the first and last iteration
#     are always printed. A value <= 0 disables printing.
#
# Returns:
#   A function(env) with attributes "call" and "name"
#   ("cb.print.evaluation").
cb.print.evaluation <- function(period = 1) {

  callback <- function(env) {

    # Printing disabled
    if (period <= 0) {
      return(invisible(NULL))
    }

    i <- env$iteration

    # Scalar condition, so use the short-circuiting || operator
    is_due <- (i - 1) %% period == 0 ||
      i == env$begin_iteration ||
      i == env$end_iteration
    if (!is_due) {
      return(invisible(NULL))
    }

    # Build the message once and reuse it for both the check and the output
    msg <- merge.eval.string(env)
    if (nchar(msg) > 0) {
      cat(msg, "\n")
    }

    invisible(NULL)

  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.print.evaluation"
  callback

}

# Create a callback that appends each iteration's evaluation results to
# env$model$record_evals.
#
# Layout written by the callback:
#   record_evals$start_iter                     <- env$begin_iteration
#   record_evals[[data_name]][[name]]$eval      <- values, one per call
#   record_evals[[data_name]][[name]]$eval_err  <- errors, one per call
#                                                  (only if errors are given)
#
# Returns:
#   A function(env) with attributes "call" and "name"
#   ("cb.record.evaluation"). The callback does nothing when env$eval_list
#   is empty.
cb.record.evaluation <- function() {

  callback <- function(env) {

    if (length(env$eval_list) <= 0) {
      return(invisible(NULL))
    }

    is_eval_err <- length(env$eval_err_list) > 0

    # First call for this model: create the empty record structure
    if (length(env$model$record_evals) == 0) {

      # Does not depend on the evaluation entry, so set it once up front
      env$model$record_evals$start_iter <- env$begin_iteration

      for (eval_res in env$eval_list) {

        data_name <- eval_res$data_name
        name <- eval_res$name

        if (is.null(env$model$record_evals[[data_name]])) {
          env$model$record_evals[[data_name]] <- list()
        }

        env$model$record_evals[[data_name]][[name]] <- list()
        env$model$record_evals[[data_name]][[name]]$eval <- list()
        env$model$record_evals[[data_name]][[name]]$eval_err <- list()

      }

    }

    # Append the current iteration's value (and error, if any) per metric
    for (j in seq_along(env$eval_list)) {

      eval_res <- env$eval_list[[j]]
      eval_err <- if (is_eval_err) env$eval_err_list[[j]] else NULL

      data_name <- eval_res$data_name
      name <- eval_res$name

      env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value)
      env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err)

    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.record.evaluation"
  callback

}

# Create a callback that flags early stopping once a metric has not improved
# for `stopping_rounds` consecutive iterations.
#
# Args:
#   stopping_rounds: number of iterations without improvement after which
#     env$met_early_stop is set to TRUE.
#   verbose: when TRUE, print a start-up note and the best iteration's
#     evaluation line on stopping. Any other value (FALSE, NULL, ...) is
#     treated as "quiet".
#
# Returns:
#   A function(env, finalize = FALSE) with attributes "call" and "name"
#   ("cb.early.stop"). `finalize` is accepted but not used by the callback.
#
# Side effects of the callback when the criterion is met: sets
#   env$best_iter and env$met_early_stop, and env$model$best_iter when a
#   model is present.
#
# Errors:
#   On the first invocation, if env$eval_list is empty.
cb.early.stop <- function(stopping_rounds, verbose = TRUE) {

  # Per-metric state, filled in by init() on the first callback invocation
  factor_to_bigger_better <- NULL
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  init <- function(env) {

    n_evals <- length(env$eval_list)

    # Early stopping cannot work without metrics
    if (n_evals == 0) {
      stop("For early stopping, valids must have at least one element")
    }

    if (isTRUE(verbose)) {
      cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
    }

    # Scores are multiplied by this factor so that "bigger is better" holds
    # for every metric: +1 for higher_better metrics, -1 otherwise
    direction <- rep(1.0, n_evals)
    for (i in seq_len(n_evals)) {
      if (!env$eval_list[[i]]$higher_better) {
        direction[i] <- -1.0
      }
    }

    factor_to_bigger_better <<- direction
    best_iter <<- rep(-1, n_evals)
    best_score <<- rep(-Inf, n_evals)
    best_msg <<- as.list(rep("", n_evals))

    # Set last so the callback only counts as initialised after validation
    eval_len <<- n_evals

  }

  callback <- function(env, finalize = FALSE) {

    if (is.null(eval_len)) {
      init(env)
    }

    cur_iter <- env$iteration

    for (i in seq_len(eval_len)) {

      score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]

      if (score > best_score[i]) {

        # New best for this metric
        best_score[i] <<- score
        best_iter[i] <<- cur_iter

        # The message is only ever printed in verbose mode, so only build it
        # then; isTRUE() keeps this consistent with the other verbose checks
        if (isTRUE(verbose)) {
          best_msg[[i]] <<- as.character(merge.eval.string(env))
        }

      } else if (cur_iter - best_iter[i] >= stopping_rounds) {

        # No improvement for stopping_rounds iterations: flag the stop
        if (!is.null(env$model)) {
          env$model$best_iter <- best_iter[i]
        }

        if (isTRUE(verbose)) {
          cat("Early stopping, best iteration is:", "\n")
          cat(best_msg[[i]], "\n")
        }

        env$best_iter <- best_iter[i]
        env$met_early_stop <- TRUE

      }

    }

  }

  attr(callback, "call") <- match.call()
  attr(callback, "name") <- "cb.early.stop"
  callback

}

# Collect the "name" attribute of every callback in cb_list.
# Returns a character vector (NULL for an empty list); callbacks without a
# "name" attribute contribute nothing to the result.
callback.names <- function(cb_list) {
  name_of <- function(cb) attr(cb, "name")
  per_callback <- lapply(cb_list, name_of)
  unlist(per_callback)
}
# Append a callback to a callback list and name every element after its
# "name" attribute.
#
# Args:
#   cb_list: existing list of callbacks.
#   cb: callback to append.
#
# Returns:
#   The combined list, named via callback.names(). If a callback named
#   "cb.early.stop" is present, the first such entry is moved to the end of
#   the list.
add.cb <- function(cb_list, cb) {

  cb_list <- c(cb_list, cb)
  names(cb_list) <- callback.names(cb_list)

  # Position of the first early-stopping callback, NA when there is none
  early_pos <- match("cb.early.stop", names(cb_list))
  if (!is.na(early_pos)) {
    cb_list <- c(cb_list[-early_pos], cb_list[early_pos])
  }

  cb_list

}

# Split a callback list by when the callbacks should run.
#
# Args:
#   cb_list: list of callback functions.
#
# Returns:
#   A list with two elements:
#     pre_iter  - callbacks whose "is_pre_iteration" attribute is set and
#                 true
#     post_iter - callbacks whose attribute is missing or false
#   Element names of cb_list are preserved in both sub-lists.
categorize.callbacks <- function(cb_list) {

  # Single predicate shared by both filters so they cannot drift apart
  is_pre_iteration <- function(cb) {
    pre <- attr(cb, "is_pre_iteration")
    !is.null(pre) && pre
  }

  list(
    pre_iter = Filter(is_pre_iteration, cb_list),
    post_iter = Filter(Negate(is_pre_iteration), cb_list)
  )

}