lgb.Booster.R 27 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
# R6 class wrapping a native LightGBM booster handle.
#
# Public fields:
#   best_iter    - iteration used by default by save_model(),
#                  save_model_to_string(), dump_model() and predict() when
#                  `num_iteration` is NULL (initialised to -1)
#   best_score   - initialised to -1; not modified inside this class
#   record_evals - initialised to an empty list; not modified inside this class
#   raw          - model string stored by save(); NA until save() is called
#
# Dataset indices used internally are 1-based: 1 is the training set and
# i + 1 is the i-th validation set. They are converted to 0-based integers
# right before being handed to the native routines.
Booster <- R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    best_iter = -1,
    best_score = -1,
    record_evals = list(),

    # Finalize will free up the handle
    finalize = function() {

      # Only free a handle that is still held
      if (!lgb.is.null.handle(private$handle)) {

        lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
        private$handle <- NULL

      }

    },

    # Initialize will create a starter booster from exactly one source, tried
    # in this order: training dataset, model file path, model string.
    #   params    - list of parameters; merged with any named values in `...`
    #   train_set - lgb.Dataset to train on, or NULL
    #   modelfile - character path to a saved model, or NULL
    #   model_str - character string holding a saved model, or NULL
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      # Create parameters and handle
      params <- append(params, list(...))
      params_str <- lgb.params2str(params)
      handle <- 0.0

      # Attempts to create a handle. Errors raised inside try() are reported
      # but do not abort here; the handle check after this block decides
      # whether construction succeeded.
      try({

        if (!is.null(train_set)) {

          # Check if training dataset is lgb.Dataset or not
          if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }

          # Store booster handle
          handle <- lgb.call("LGBM_BoosterCreate_R",
                             ret = handle,
                             train_set$.__enclos_env__$private$get_handle(),
                             params_str)

          # Create private booster information
          private$train_set <- train_set
          private$num_dataset <- 1
          private$init_predictor <- train_set$.__enclos_env__$private$predictor

          # Merge the dataset's predictor into the new booster, if there is one
          if (!is.null(private$init_predictor)) {

            lgb.call("LGBM_BoosterMerge_R",
                     ret = NULL,
                     handle,
                     private$init_predictor$.__enclos_env__$private$handle)

          }

          # The training set has no cached prediction for the current iteration
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          # Do we have a model file as character?
          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          # Create booster from model file
          handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R",
                             ret = handle,
                             lgb.c_str(modelfile))

        } else if (!is.null(model_str)) {

          # Do we have a model_str as character?
          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          # Create booster from model string
          handle <- lgb.call("LGBM_BoosterLoadModelFromString_R",
                             ret = handle,
                             lgb.c_str(model_str))

        } else {

          # Booster non existent
          stop("lgb.Booster: Need at least either training dataset, model file, or model_str to create booster instance")

        }

      })

      # Check whether the handle was created properly if it was not stopped
      # earlier by a stop call
      if (lgb.is.null.handle(handle)) {

        stop("lgb.Booster: cannot create Booster handle")

      } else {

        # Create class
        class(handle) <- "lgb.Booster.handle"
        private$handle <- handle
        private$num_class <- 1L
        private$num_class <- lgb.call("LGBM_BoosterGetNumClasses_R",
                                      ret = private$num_class,
                                      private$handle)

      }

    },

    # Set training data name (used as data_name in evaluation results).
    # Returns self invisibly.
    set_train_data_name = function(name) {

      private$name_train_set <- name
      return(invisible(self))

    },

    # Add validation data under the given name. Returns self invisibly.
    add_valid = function(data, name) {

      # Check if data is lgb.Dataset
      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # Check if predictors are identical
      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop("lgb.Booster.add_valid: Failed to add validation data; you should use the same predictor for these data")
      }

      # Check if names are character
      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Add validation data to booster
      lgb.call("LGBM_BoosterAddValidData_R",
               ret = NULL,
               private$handle,
               data$.__enclos_env__$private$get_handle())

      # Store private information
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Reset parameters of booster. `params` is merged with any named values
    # in `...`. Returns self invisibly.
    reset_parameter = function(params, ...) {

      # Append parameters
      params <- append(params, list(...))
      params_str <- lgb.params2str(params)

      # Reset parameters
      lgb.call("LGBM_BoosterResetParameter_R",
               ret = NULL,
               private$handle,
               params_str)

      return(invisible(self))

    },

    # Perform boosting update iteration.
    #   train_set - optional lgb.Dataset replacing the current training data
    #   fobj      - optional custom objective, called as
    #               fobj(predictions, train_set); must return a list with
    #               elements `grad` and `hess`
    # Returns whatever the native update call returns.
    update = function(train_set = NULL, fobj = NULL) {

      # Check if training set is not null
      if (!is.null(train_set)) {

        # Check if training set is lgb.Dataset
        if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Check if predictors are identical. The predictor is read from the
        # dataset's private environment, as in initialize() and add_valid().
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        # Reset training data on booster
        lgb.call("LGBM_BoosterResetTrainingData_R",
                 ret = NULL,
                 private$handle,
                 train_set$.__enclos_env__$private$get_handle())

        # Store private train set
        private$train_set <- train_set

      }

      # Check if objective is empty
      if (is.null(fobj)) {

        # Boost iteration from known objective
        ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)

      } else {

        # Check if objective is function
        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # Perform objective calculation on the training-set predictions
        gpair <- fobj(private$inner_predict(1), private$train_set)

        # Check for gradient and hessian as list
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should return a list with attributes (hess, grad)")
        }

        # Boost iteration from custom gradient/hessian
        ret <- lgb.call("LGBM_BoosterUpdateOneIterCustom_R",
                        ret = NULL,
                        private$handle,
                        gpair$grad,
                        gpair$hess,
                        length(gpair$grad))

      }

      # Cached predictions of every dataset are stale after the update
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(ret)

    },

    # Return one iteration behind. Returns self invisibly.
    rollback_one_iter = function() {

      lgb.call("LGBM_BoosterRollbackOneIter_R",
               ret = NULL,
               private$handle)

      # Cached predictions of every dataset are stale after the rollback
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Get current iteration
    current_iter = function() {

      cur_iter <- 0L
      lgb.call("LGBM_BoosterGetCurrentIteration_R",
               ret = cur_iter,
               private$handle)

    },

    # Evaluate data on metrics. If `data` is neither the training set nor an
    # already registered validation set, it is first added under `name`.
    eval = function(data, name, feval = NULL) {

      # Check if dataset is lgb.Dataset
      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the dataset; 0 means "not known yet"
      data_idx <- 0
      if (identical(data, private$train_set)) {
        data_idx <- 1
      } else {

        # Loop through each validation set (no-op when there are none)
        for (i in seq_along(private$valid_sets)) {

          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1
            break
          }

        }

      }

      # Unknown dataset: register it as a new validation set
      if (data_idx == 0) {

        self$add_valid(data, name)
        data_idx <- private$num_dataset

      }

      # Evaluate data
      private$inner_eval(name, data_idx, feval)

    },

    # Evaluation training data
    eval_train = function(feval = NULL) {
      private$inner_eval(private$name_train_set, 1, feval)
    },

    # Evaluation validation data; returns an empty list when no validation
    # set has been added
    eval_valid = function(feval = NULL) {

      ret <- list()

      # Loop through each validation set (no-op when there are none)
      for (i in seq_along(private$valid_sets)) {
        ret <- append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval))
      }

      return(ret)

    },

    # Save model to `filename`. A NULL `num_iteration` falls back to
    # `best_iter`. Returns self invisibly.
    save_model = function(filename, num_iteration = NULL) {

      # Check if number of iteration is non existent
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Save booster model
      lgb.call("LGBM_BoosterSaveModel_R",
               ret = NULL,
               private$handle,
               as.integer(num_iteration),
               lgb.c_str(filename))

      return(invisible(self))
    },

    # Save model to string. A NULL `num_iteration` falls back to `best_iter`.
    save_model_to_string = function(num_iteration = NULL) {

      # Check if number of iteration is non existent
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Return model string
      return(lgb.call.return.str("LGBM_BoosterSaveModelToString_R",
                                 private$handle,
                                 as.integer(num_iteration)))

    },

    # Dump model in memory. A NULL `num_iteration` falls back to `best_iter`.
    dump_model = function(num_iteration = NULL) {

      # Check if number of iteration is non existent
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Return dumped model
      lgb.call.return.str("LGBM_BoosterDumpModel_R",
                          private$handle,
                          as.integer(num_iteration))

    },

    # Predict on new data. A NULL `num_iteration` falls back to `best_iter`;
    # `...` is forwarded to the Predictor constructor.
    predict = function(data,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       header = FALSE,
                       reshape = FALSE, ...) {

      # Check if number of iteration is non existent
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Predict on new data
      predictor <- Predictor$new(private$handle, ...)
      predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)

    },

    # Transform into predictor
    to_predictor = function() {
      Predictor$new(private$handle)
    },

    # Used for save
    raw = NA,

    # Store the model string in `raw` for in-memory saving
    save = function() {

      # Overwrite model in object
      self$raw <- self$save_model_to_string(NULL)

    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1,
    num_dataset = 0,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,

    # Predict data for dataset `idx` (1 = training set, i + 1 = i-th
    # validation set), reusing the buffer when it is still current
    inner_predict = function(idx) {

      # Check for unknown dataset (over the maximum provided range) before
      # the name lookup, so that callers get this message rather than a
      # subscript error
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Store data name
      data_name <- private$name_train_set
      if (idx > 1) {
        data_name <- private$name_valid_sets[[idx - 1]]
      }

      # Allocate the prediction buffer on first use
      if (is.null(private$predict_buffer[[data_name]])) {

        npred <- 0L
        npred <- lgb.call("LGBM_BoosterGetNumPredict_R",
                          ret = npred,
                          private$handle,
                          as.integer(idx - 1))
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer if the current iteration was not predicted yet
      if (!private$is_predicted_cur_iter[[idx]]) {

        private$predict_buffer[[data_name]] <- lgb.call("LGBM_BoosterGetPredict_R",
                                                        ret = private$predict_buffer[[data_name]],
                                                        private$handle,
                                                        as.integer(idx - 1))
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      # Return prediction buffer
      return(private$predict_buffer[[data_name]])
    },

    # Get evaluation information; the metric names are fetched once and
    # cached in `eval_names`
    get_eval_info = function() {

      # Check for evaluation names emptiness
      if (is.null(private$eval_names)) {

        # Get evaluation names as one tab-separated string
        names <- lgb.call.return.str("LGBM_BoosterGetEvalNames_R",
                                     private$handle)

        # Check names' length
        if (nchar(names) > 0) {

          # Parse and store privately names; metrics whose name starts with
          # "ndcg" or is exactly "auc" are flagged as higher-is-better
          names <- strsplit(names, "\t")[[1]]
          private$eval_names <- names
          private$higher_better_inner_eval <- grepl("^ndcg|^auc$", names)

        }

      }

      # Return evaluation names
      return(private$eval_names)

    },

    # Perform inner evaluation of dataset `data_idx`. Returns a list with one
    # entry (data_name, name, value, higher_better) per built-in metric, plus
    # one entry for `feval` when it is supplied.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Get evaluation information
      private$get_eval_info()

      # Prepare return
      ret <- list()

      # Check evaluation names existence
      if (length(private$eval_names) > 0) {

        # Create evaluation values
        tmp_vals <- numeric(length(private$eval_names))
        tmp_vals <- lgb.call("LGBM_BoosterGetEval_R",
                             ret = tmp_vals,
                             private$handle,
                             as.integer(data_idx - 1))

        # Loop through all evaluation names
        for (i in seq_along(private$eval_names)) {

          # Store evaluation and append to return
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Check if there is a custom evaluation function
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset handed to feval
        data <- private$train_set
        if (data_idx > 1) {
          data <- private$valid_sets[[data_idx - 1]]
        }

        # Perform function evaluation
        res <- feval(private$inner_predict(data_idx), data)

        # Check for name correctness
        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)")
        }

        # Append names and evaluation
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      # Return ret
      return(ret)

    }

  )
)


#' Predict method for LightGBM model
#'
#' Predicted values based on class \code{lgb.Booster}
#'
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'        sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for
#'        logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'        prediction outputs per case.
#' @param ... additional arguments forwarded to the \code{predict} method of \code{object}
#'
#' @return
#' For regression or binary classification, it returns a vector of length \code{nrows(data)}.
#' For multiclass classification, either a \code{num_class * nrows(data)} vector or
#' a \code{(nrows(data), num_class)} dimension matrix is returned, depending on
#' the \code{reshape} value.
#'
#' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees.
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' preds <- predict(model, test$data)
#' }
#'
#' @rdname predict.lgb.Booster
#' @export
predict.lgb.Booster <- function(object, data,
                        num_iteration = NULL,
                        rawscore = FALSE,
                        predleaf = FALSE,
                        header = FALSE,
                        reshape = FALSE, ...) {

  # Refuse anything that is not a booster before touching its methods
  is_booster <- lgb.is.Booster(object)
  if (!is_booster) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the booster's own predict method
  object$predict(data = data,
                 num_iteration = num_iteration,
                 rawscore = rawscore,
                 predleaf = predleaf,
                 header = header,
                 reshape = reshape,
                 ...)
}

#' Load LightGBM model
#'
#' Load LightGBM model from saved model file or string.
#' Load LightGBM takes in either a file path or model string.
#' If both are provided, Load will default to loading from file.
#'
#' @param filename path of model file
#' @param model_str a string containing the model
#'
#' @return lgb.Booster
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' lgb.save(model, "model.txt")
#' load_booster <- lgb.load(filename = "model.txt")
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#'
#' @rdname lgb.load
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  has_file <- !is.null(filename)
  has_string <- !is.null(model_str)

  # At least one source is required
  if (!has_file && !has_string) {
    stop("lgb.load: either filename or model_str must be given")
  }

  # A file name takes precedence over a model string
  if (has_file) {

    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }

    if (!file.exists(filename)) {
      stop("lgb.load: file does not exist for supplied filename")
    }

    return(invisible(Booster$new(modelfile = filename)))

  }

  # Only a model string was given
  if (!is.character(model_str)) {
    stop("lgb.load: model_str should be character")
  }

  invisible(Booster$new(model_str = model_str))

}

#' Save LightGBM model
#'
#' Save LightGBM model to a file
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return lgb.Booster (returned invisibly)
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' lgb.save(model, "model.txt")
#' }
#'
#' @rdname lgb.save
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  # Validate the booster first, then the target file name
  booster_ok <- lgb.is.Booster(booster)
  if (!booster_ok) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  name_ok <- is.character(filename)
  if (!name_ok) {
    stop("lgb.save: filename should be a character")
  }

  # Delegate to the booster; its result is passed back invisibly
  saved <- booster$save_model(filename = filename, num_iteration = num_iteration)
  invisible(saved)

}

#' Dump LightGBM model to json
#'
#' Dump LightGBM model to json
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump, NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' json_model <- lgb.dump(model)
#' }
#'
#' @rdname lgb.dump
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # Check if booster is booster; the message names this function so the
  # failing call can be identified
  if (!lgb.is.Booster(booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster dump at requested iteration
  booster$dump_model(num_iteration)

}

#' Get record evaluation result from booster
#'
#' Look up the evaluation values stored in the \code{record_evals} field
#' of a booster for one dataset name and one evaluation name.
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name name of dataset
#' @param eval_name name of evaluation
#' @param iters iterations, NULL will return all
#' @param is_err TRUE will return evaluation error instead
#'
#' @return vector of evaluation result
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    100,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 10)
#' lgb.get.eval.result(model, "test", "l2")
#' }
#'
#' @rdname lgb.get.eval.result
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Only a booster object is accepted
  if (!lgb.is.Booster(booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Both lookup keys have to be character
  if (!(is.character(data_name) && is.character(eval_name))) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # Records for the requested dataset must exist
  data_record <- booster$record_evals[[data_name]]
  if (is.null(data_record)) {
    stop("lgb.get.eval.result: wrong data name")
  }

  # Records for the requested evaluation must exist within that dataset
  metric_record <- data_record[[eval_name]]
  if (is.null(metric_record)) {
    stop("lgb.get.eval.result: wrong eval name")
  }

  # Start from the evaluation values, switch to the error values on request
  result <- metric_record$eval
  if (is_err) {
    result <- metric_record$eval_err
  }

  # When specific iterations are requested, shift them by the recorded
  # start iteration so they become positions within the stored results
  if (!is.null(iters)) {
    positions <- as.integer(iters) - (booster$record_evals$start_iter - 1)
    result <- result[positions]
  }

  # Hand back everything selected as a numeric vector
  as.numeric(result)

}