lgb.Booster.R 24.5 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
# R6 wrapper around a native LightGBM booster handle.
#
# The object is created either from a training lgb.Dataset or from a model
# file. It keeps the datasets registered with the native booster, a
# per-dataset prediction buffer, and the metric names reported by the
# native library.
Booster <- R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Iteration used by save_model(), dump_model() and predict() whenever
    # the caller leaves num_iteration = NULL
    best_iter = -1,

    # Recorded evaluation results; read by lgb.get.eval.result()
    record_evals = list(),

    # Finalize will free up the handles
    finalize = function() {

      # Only free a handle that is still live
      if (!lgb.is.null.handle(private$handle)) {
        lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
        private$handle <- NULL
      }

    },

    # Initialize will create a starter booster, either from a training
    # lgb.Dataset (train_set) or from a model file path (modelfile).
    # train_set takes precedence when both are given. Extra named arguments
    # passed through `...` are appended to `params`.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          ...) {

      # Create parameters and handle
      params <- append(params, list(...))
      params_str <- lgb.params2str(params)
      handle <- lgb.new.handle()

      if (!is.null(train_set)) {

        # Check if training dataset is lgb.Dataset or not
        if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }

        # Store booster handle
        handle <- lgb.call("LGBM_BoosterCreate_R",
                           ret = handle,
                           train_set$.__enclos_env__$private$get_handle(),
                           params_str)

        # Create private booster information
        private$train_set <- train_set
        private$num_dataset <- 1
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        # Merge the dataset's initial predictor into the new booster, if any
        if (!is.null(private$init_predictor)) {
          lgb.call("LGBM_BoosterMerge_R",
                   ret = NULL,
                   handle,
                   private$init_predictor$.__enclos_env__$private$handle)
        }

        # The training set has no cached prediction for the current iteration
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        # Do we have a model file as character?
        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        # Create booster from model
        handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R",
                           ret = handle,
                           lgb.c_str(modelfile))

      } else {

        # Booster non existent
        stop("lgb.Booster: Need at least either training dataset or model file to create booster instance")

      }

      # Tag and store the handle, then ask the native booster for its class count
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      private$num_class <- 1L
      private$num_class <- lgb.call("LGBM_BoosterGetNumClasses_R",
                                    ret = private$num_class,
                                    private$handle)

    },

    # Set the name under which training-set results are reported; returns self
    set_train_data_name = function(name) {

      private$name_train_set <- name
      self

    },

    # Register `data` (an lgb.Dataset) as a validation set called `name`;
    # returns self
    add_valid = function(data, name) {

      # Check if data is lgb.Dataset
      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # Check if predictors are identical
      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop("lgb.Booster.add_valid: Failed to add validation data; you should use the same predictor for these data")
      }

      # Check if names are character
      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Add validation data to booster
      lgb.call("LGBM_BoosterAddValidData_R",
               ret = NULL,
               private$handle,
               data$.__enclos_env__$private$get_handle())

      # Store private information
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      # Return self
      return(self)

    },

    # Reset parameters of booster; `...` is appended to `params`. Returns self
    reset_parameter = function(params, ...) {

      # Append parameters (uses the same helper as initialize())
      params <- append(params, list(...))
      params_str <- lgb.params2str(params)

      # Reset parameters
      lgb.call("LGBM_BoosterResetParameter_R",
               ret = NULL,
               private$handle,
               params_str)

      # Return self
      return(self)

    },

    # Perform one boosting iteration.
    #   train_set: optional lgb.Dataset that replaces the current training data
    #   fobj:      optional custom objective, called as
    #              fobj(predictions, train_set); must return a list with
    #              elements `grad` and `hess`
    # Returns whatever the native update call returns.
    update = function(train_set = NULL, fobj = NULL) {

      if (!is.null(train_set)) {

        # Check if training set is lgb.Dataset
        if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # The predictor is read through the dataset's private environment,
        # exactly as initialize() and add_valid() do
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        # Reset training data on booster
        lgb.call("LGBM_BoosterResetTrainingData_R",
                 ret = NULL,
                 private$handle,
                 train_set$.__enclos_env__$private$get_handle())

        # Store private train set
        private$train_set <- train_set

      }

      if (is.null(fobj)) {

        # Boost iteration from known objective
        ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)

      } else {

        # Check if objective is function
        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # Perform objective calculation on the training-set predictions
        gpair <- fobj(private$inner_predict(1), private$train_set)

        # Check for gradient and hessian as list
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should ",
               "return a list with attributes (hess, grad)")
        }

        # Boost one iteration with the custom gradient/hessian
        ret <- lgb.call("LGBM_BoosterUpdateOneIterCustom_R",
                        ret = NULL,
                        private$handle,
                        gpair$grad,
                        gpair$hess,
                        length(gpair$grad))

      }

      # The model changed: invalidate every cached prediction
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      # Return the result of the native update call
      return(ret)

    },

    # Undo the most recent boosting iteration; returns self
    rollback_one_iter = function() {

      lgb.call("LGBM_BoosterRollbackOneIter_R",
               ret = NULL,
               private$handle)

      # The model changed: invalidate every cached prediction
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      # Return self
      return(self)

    },

    # Get current iteration as reported by the native booster
    current_iter = function() {

      cur_iter <- 0L
      lgb.call("LGBM_BoosterGetCurrentIteration_R",
               ret = cur_iter,
               private$handle)

    },

    # Evaluate `data` (an lgb.Dataset) and report the results under `name`.
    # A dataset that is neither the training set nor an already registered
    # validation set is registered first via add_valid().
    eval = function(data, name, feval = NULL) {

      # Check if dataset is lgb.Dataset
      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the dataset: 1 is the training set, i + 1 the i-th validation set
      data_idx <- 0
      if (identical(data, private$train_set)) {
        data_idx <- 1
      } else {

        if (length(private$valid_sets) > 0) {

          for (i in seq_along(private$valid_sets)) {

            if (identical(data, private$valid_sets[[i]])) {
              data_idx <- i + 1
              break
            }

          }

        }

      }

      # Unknown dataset: register it as a new validation set
      if (data_idx == 0) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      # Evaluate data
      private$inner_eval(name, data_idx, feval)

    },

    # Evaluate the training data
    eval_train = function(feval = NULL) {
      private$inner_eval(private$name_train_set, 1, feval)
    },

    # Evaluate every registered validation set; returns one flat list of
    # result entries (empty when no validation set is registered)
    eval_valid = function(feval = NULL) {

      ret <- list()

      # Check if validation is empty
      if (length(private$valid_sets) <= 0) {
        return(ret)
      }

      # Loop through each validation set
      for (i in seq_along(private$valid_sets)) {
        ret <- append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval))
      }

      # Return ret
      return(ret)

    },

    # Save model to `filename`; num_iteration = NULL falls back to best_iter.
    # Returns self
    save_model = function(filename, num_iteration = NULL) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Save booster model
      lgb.call("LGBM_BoosterSaveModel_R",
               ret = NULL,
               private$handle,
               as.integer(num_iteration),
               lgb.c_str(filename))

      # Return self
      return(self)
    },

    # Dump model in memory; num_iteration = NULL falls back to best_iter
    dump_model = function(num_iteration = NULL) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Return dumped model
      lgb.call.return.str("LGBM_BoosterDumpModel_R",
                          private$handle,
                          as.integer(num_iteration))

    },

    # Predict on new data through a Predictor built on this booster's handle;
    # num_iteration = NULL falls back to best_iter
    predict = function(data,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       header = FALSE,
                       reshape = FALSE) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Predict on new data
      predictor <- Predictor$new(private$handle)
      predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)

    },

    # Transform into predictor
    to_predictor = function() {
      Predictor$new(private$handle)
    },

    # Text of the saved model, filled in by save(); NA until then
    raw = NA,

    # Save model to temporary file for in-memory saving: the model text is
    # read back into self$raw and the temporary file is removed
    save = function() {

      # Create temporary file and make sure it is removed even when saving or
      # reading fails (unlink() on an already removed file is a silent no-op)
      temp <- tempfile()
      on.exit(unlink(temp), add = TRUE)

      # Save model to file
      lgb.save(self, temp)

      # Overwrite model in object
      self$raw <- readChar(temp, file.info(temp)$size)

      # Remove temporary file
      file.remove(temp)

    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1,
    num_dataset = 0,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,

    # Return the predictions for dataset `idx` (1 = training set,
    # i + 1 = i-th validation set), refreshing the per-dataset buffer only
    # when it is stale for the current iteration
    inner_predict = function(idx) {

      # Reject unknown datasets before idx is used to index anything, so the
      # caller gets this message instead of a subscript error
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Buffers are keyed by dataset name
      data_name <- private$name_train_set
      if (idx > 1) {
        data_name <- private$name_valid_sets[[idx - 1]]
      }

      # Allocate the buffer on first use
      if (is.null(private$predict_buffer[[data_name]])) {

        npred <- 0L
        npred <- lgb.call("LGBM_BoosterGetNumPredict_R",
                          ret = npred,
                          private$handle,
                          as.integer(idx - 1))
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer if it was not yet filled for the current iteration
      if (!private$is_predicted_cur_iter[[idx]]) {

        private$predict_buffer[[data_name]] <- lgb.call("LGBM_BoosterGetPredict_R",
                                                        ret = private$predict_buffer[[data_name]],
                                                        private$handle,
                                                        as.integer(idx - 1))
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      # Return prediction buffer
      return(private$predict_buffer[[data_name]])
    },

    # Fetch (once) and return the metric names reported by the native
    # booster, recording for each whether a higher value is better
    get_eval_info = function() {

      if (is.null(private$eval_names)) {

        # Get evaluation names as one tab-separated string
        names <- lgb.call.return.str("LGBM_BoosterGetEvalNames_R",
                                     private$handle)

        if (nchar(names) > 0) {

          # Parse and store privately names
          names <- strsplit(names, "\t")[[1]]
          private$eval_names <- names

          # Only "auc" and metrics whose name starts with "ndcg" are treated
          # as higher-is-better
          private$higher_better_inner_eval <- (names == "auc") | grepl("^ndcg", names)

        }

      }

      # Return evaluation names
      return(private$eval_names)

    },

    # Evaluate dataset `data_idx` and return a list of result entries, each
    # a list with data_name, name, value and higher_better. When `feval` is
    # given it is called as feval(predictions, dataset) and its result is
    # appended as one more entry.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Make sure metric names are loaded
      private$get_eval_info()

      # Prepare return
      ret <- list()

      if (length(private$eval_names) > 0) {

        # Ask the native booster for one value per metric
        tmp_vals <- numeric(length(private$eval_names))
        tmp_vals <- lgb.call("LGBM_BoosterGetEval_R",
                             ret = tmp_vals,
                             private$handle,
                             as.integer(data_idx - 1))

        # Loop through all evaluation names
        for (i in seq_along(private$eval_names)) {

          # Store evaluation and append to return
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Custom evaluation function, if supplied
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset matching data_idx
        data <- private$train_set
        if (data_idx > 1) {
          data <- private$valid_sets[[data_idx - 1]]
        }

        # Perform function evaluation
        res <- feval(private$inner_predict(data_idx), data)

        # The custom result must carry all three fields
        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a ",
               "list with attribute (name, value, higher_better)")
        }

        # Append names and evaluation
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      # Return ret
      return(ret)

    }

  )
)


#' Predict method for LightGBM model
#'
#' Predicted values based on class \code{lgb.Booster}
#'
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param num_iteration number of iterations to predict with, NULL or <= 0 means use best iteration
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'        sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for
#'        logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'        prediction outputs per case.
#' @param ... ignored; accepted only so the signature is compatible with the
#'        \code{predict} generic.
#'
#' @return
#' For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#' For multiclass classification, either a \code{num_class * nrow(data)} vector or
#' a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#' the \code{reshape} value.
#'
#' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees.
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' preds <- predict(model, test$data)
#' }
#'
#' @rdname predict.lgb.Booster
#' @export
predict.lgb.Booster <- function(object, data,
                        num_iteration = NULL,
                        rawscore = FALSE,
                        predleaf = FALSE,
                        header = FALSE,
                        reshape = FALSE,
                        ...) {

  # Check booster existence
  if (!lgb.is.Booster(object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Return booster predictions; `...` is deliberately not forwarded because
  # the booster's predict() method takes exactly these six arguments
  object$predict(data,
                 num_iteration,
                 rawscore,
                 predleaf,
                 header,
                 reshape)
}

#' Load LightGBM model
#'
#' Load LightGBM model from saved model file
#'
#' @param filename path of model file
#'
#' @return booster
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' lgb.save(model, "model.txt")
#' load_booster <- lgb.load("model.txt")
#' }
#'
#' @rdname lgb.load
#' @export
lgb.load <- function(filename) {

  # A character path is the only accepted input; build the booster from it
  if (is.character(filename)) {
    return(Booster$new(modelfile = filename))
  }

  # Anything else is rejected
  stop("lgb.load: filename should be character")

}

#' Save LightGBM model
#'
#' Save LightGBM model
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return booster
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' lgb.save(model, "model.txt")
#' }
#'
#' @rdname lgb.save
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  # The booster is validated first, then the file name
  booster_ok <- lgb.is.Booster(booster)
  if (!booster_ok) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # With a character file name, delegate to the booster's own save_model()
  if (is.character(filename)) {
    return(booster$save_model(filename, num_iteration))
  }

  stop("lgb.save: filename should be a character")

}

#' Dump LightGBM model to json
#'
#' Dump LightGBM model to json
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump, NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' json_model <- lgb.dump(model)
#' }
#'
#' @rdname lgb.dump
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # Check if booster is booster (the message names this function, not lgb.save)
  if (!lgb.is.Booster(booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return the model dumped at the requested iteration
  booster$dump_model(num_iteration)

}

#' Get record evaluation result from booster
#'
#' Get record evaluation result from booster
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name name of dataset
#' @param eval_name name of evaluation
#' @param iters iterations, NULL will return all
#' @param is_err TRUE will return evaluation error instead
#'
#' @return vector of evaluation result
#'
#' @rdname lgb.get.eval.result
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Only boosters carry recorded evaluations
  if (!lgb.is.Booster(booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Both lookup keys have to be character
  keys_are_character <- is.character(data_name) && is.character(eval_name)
  if (!keys_are_character) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # Walk down the record: dataset first, then metric
  if (is.null(booster$record_evals[[data_name]])) {
    stop("lgb.get.eval.result: wrong data name")
  }
  metric_record <- booster$record_evals[[data_name]][[eval_name]]
  if (is.null(metric_record)) {
    stop("lgb.get.eval.result: wrong eval name")
  }

  # Pick either the evaluation values or their errors
  if (is_err) {
    values <- metric_record$eval_err
  } else {
    values <- metric_record$eval
  }

  # Without an iteration filter everything is returned
  if (is.null(iters)) {
    return(as.numeric(values))
  }

  # Shift the requested iterations by the recorded start iteration so they
  # index into the stored values
  offset <- booster$record_evals$start_iter - 1
  positions <- as.integer(iters) - offset
  as.numeric(values[positions])
}