# lgb.Booster.R
#' @importFrom R6 R6Class
#' @importFrom utils modifyList
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    best_iter = -1L,
    best_score = NA_real_,
    params = list(),
    record_evals = list(),

    # Finalize will free up the handles.
    finalize = function() {

      # Only call into the native library when there is a live handle, using
      # the same lgb.is.null.handle() check the constructor applies to a new
      # handle (the handle may never have been stored if construction failed).
      if (!isTRUE(lgb.is.null.handle(x = private$handle))) {
        .Call(
          LGBM_BoosterFree_R
          , private$handle
        )
      }
      private$handle <- NULL
      return(invisible(NULL))
    },

    # Initialize will create a starter booster.
    # Exactly one source is used, checked in this order:
    #   * train_set: an lgb.Dataset; its parameters are merged into `params`
    #   * modelfile: path to a saved model file
    #   * model_str: a model serialized to a string
    # The (possibly merged) `params` are stored in `self$params`.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      # This used to be wrapped in try(), which printed and then discarded any
      # error: validation failures were followed by a second, generic error,
      # and a failure *after* the handle already existed (e.g. in
      # LGBM_BoosterMerge_R) was silently ignored, leaving a half-initialized
      # booster. Errors are now re-raised, keeping the historical
      # "cannot create Booster handle" prefix together with the original cause.
      # Assignments inside the braces happen in this function's frame.
      tryCatch({

        if (!is.null(train_set)) {

          if (!lgb.is.Dataset(train_set)) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }
          train_set_handle <- train_set$.__enclos_env__$private$get_handle()
          params <- utils::modifyList(params, train_set$get_params())
          params_str <- lgb.params2str(params = params)

          handle <- .Call(
            LGBM_BoosterCreate_R
            , train_set_handle
            , params_str
          )

          # Record training-set bookkeeping (index 1 is always the training set)
          private$train_set <- train_set
          private$train_set_version <- train_set$.__enclos_env__$private$version
          private$num_dataset <- 1L
          private$init_predictor <- train_set$.__enclos_env__$private$predictor

          # Continue from an existing model if the Dataset carries a predictor
          if (!is.null(private$init_predictor)) {
            .Call(
              LGBM_BoosterMerge_R
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )
          }

          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          modelfile <- path.expand(modelfile)

          handle <- .Call(
            LGBM_BoosterCreateFromModelfile_R
            , modelfile
          )

        } else if (!is.null(model_str)) {

          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          handle <- .Call(
            LGBM_BoosterLoadModelFromString_R
            , model_str
          )

        } else {

          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      }, error = function(e) {
        stop(
          "lgb.Booster: cannot create Booster handle: "
          , conditionMessage(e)
          , call. = FALSE
        )
      })

      # Guard against a handle that came back empty without an error
      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      # The native routine reports its result by writing into the vector it is
      # given, so pass a freshly allocated integer rather than a literal
      # constant (writing into a literal could corrupt it for later calls).
      num_class <- integer(1L)
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , num_class
      )
      private$num_class <- num_class

      self$params <- params

      return(invisible(NULL))

    },

    # Set training data name
    set_train_data_name = function(name) {

      private$name_train_set <- name
      return(invisible(self))

    },

    # Add validation data; it must share the booster's init predictor.
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Merge `params` (and, deprecated, `...`) into the stored parameters and
    # push the result to the native booster.
    reset_parameter = function(params, ...) {

      additional_params <- list(...)
      if (length(additional_params) > 0L) {
        warning(paste0(
          "Booster$reset_parameter(): Found the following passed through '...': "
          , paste(names(additional_params), collapse = ", ")
          , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
          , "Add these to 'params' instead."
        ))
      }

      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }

      params <- utils::modifyList(params, additional_params)
      params_str <- lgb.params2str(params = params)

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))

    },

    # Perform boosting update iteration
    update = function(train_set = NULL, fobj = NULL) {

      # Re-attach the stored training set if it changed since it was attached.
      # Boosters created from a model file/string have no stored training set;
      # previously `NULL != version` produced logical(0) and a cryptic error.
      if (is.null(train_set) && !is.null(private$train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Read the predictor the same way initialize() and add_valid() do; the
        # previous `train_set$predictor` did not match that access path.
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # A custom objective was used before, so the native objective is "none"
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        # Boost iteration from known objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        # Custom objective: compute gradient/hessian from current train preds
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should
            return a list with attributes (hess, grad)")
        }

        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # Cached predictions are stale after a new iteration
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Return one iteration behind
    rollback_one_iter = function() {

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # Cached predictions are stale after rolling back
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Get current iteration (written by the native call into a fresh integer)
    current_iter = function() {

      cur_iter <- integer(1L)
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Get upper bound (written by the native call into a fresh double)
    upper_bound = function() {

      upper_bound <- numeric(1L)
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)

    },

    # Get lower bound (written by the native call into a fresh double)
    lower_bound = function() {

      lower_bound <- numeric(1L)
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)

    },

    # Evaluate data on metrics; unknown data is added as a validation set.
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the data: index 1 is the training set, 2.. the validation sets
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1L
            break
          }
        }
      }

      # Not seen before: register it as validation data
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },

    # Evaluation training data
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluation validation data
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },

    # Save model
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      filename <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },

    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)

    },

    # Dump model in memory
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)

    },

    # Predict on new data
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE,
                       params = list(),
                       ...) {

      additional_params <- list(...)
      if (length(additional_params) > 0L) {
        warning(paste0(
          "Booster$predict(): Found the following passed through '...': "
          , paste(names(additional_params), collapse = ", ")
          , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
          , "Add these to 'params' instead. See ?predict.lgb.Booster for documentation on how to call this function."
        ))
      }

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      params <- utils::modifyList(params, additional_params)
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
          , reshape = reshape
        )
      )

    },

    # Transform into predictor
    to_predictor = function() {
      return(Predictor$new(modelfile = private$handle))
    },

    # Used for save
    raw = NA,

    # Store the model string in `self$raw` for in-memory saving
    save = function() {

      self$raw <- self$save_model_to_string(NULL)

      return(invisible(NULL))

    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,
    train_set_version = 0L,

    # Return (cached) predictions for dataset `idx` (1 = training set,
    # 2.. = validation sets), refreshing them at most once per iteration.
    inner_predict = function(idx) {

      # Validate first, so an out-of-range index reports this message instead
      # of a "subscript out of bounds" from the name lookup below.
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the prediction buffer on first use; the native call writes
      # the required length into a fresh integer.
      if (is.null(private$predict_buffer[[data_name]])) {
        npred <- integer(1L)
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)
      }

      # Fill the buffer in place unless this iteration was already predicted
      if (!private$is_predicted_cur_iter[[idx]]) {
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Get evaluation information (names cached after the first non-empty fetch)
    get_eval_info = function() {

      if (is.null(private$eval_names)) {
        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {

          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },

    # Evaluate built-in metrics (and optionally `feval`) on dataset `data_idx`
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Native call writes one value per metric into this buffer
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))
        }

      }

      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) ||  is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a
            list with attribute (name, value, higher_better)")
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }

  )
)

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
#'             a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param start_iteration integer or \code{NULL}, optional (default = \code{NULL}).
#'                        Start index of the iteration to predict.
#'                        If \code{NULL} or <= 0, starts from the first iteration.
#' @param num_iteration integer or \code{NULL}, optional (default = \code{NULL}).
#'                      Limit number of iterations in the prediction.
#'                      If \code{NULL}, if the best iteration exists and start_iteration is \code{NULL} or <= 0,
#'                      the best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the form of the original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf indices instead.
#' @param predcontrib whether to return per-feature contributions for each record.
#' @param header only used for prediction on a text file. \code{TRUE} if the text file has a header.
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values.
#' @param ... Additional prediction parameters. NOTE: deprecated as of v3.3.0. Use \code{params} instead.
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
#'
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'    )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                params = list(),
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Deprecated path: values passed through '...' are still honoured, but the
  # caller is told to move them into 'params'.
  extra_params <- list(...)
  if (length(extra_params) > 0L) {
    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , paste(names(extra_params), collapse = ", ")
      , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
      , "Add these to 'params' instead. See ?predict.lgb.Booster for documentation on how to call this function."
    ))
  }

  # Entries from '...' override same-named entries of 'params'.
  merged_params <- utils::modifyList(params, extra_params)

  # All real work happens in the R6 method.
  predictions <- object$predict(
    data = data
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , reshape = reshape
    , params = merged_params
  )

  return(predictions)
}

#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  handle <- x$.__enclos_env__$private$handle
  handle_is_null <- lgb.is.null.handle(handle)

  # Without a live handle nothing can be queried from the model.
  if (handle_is_null) {
    cat("LightGBM Model\n")
    cat("(Booster handle is invalid)\n")
    return(invisible(x))
  }

  ntrees <- x$current_iter()
  if (ntrees == 1L) {
    cat("LightGBM Model (1 tree)\n")
  } else {
    cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
  }

  # `params$objective` is not guaranteed to be a single string: it is NULL when
  # the booster's params never recorded it (e.g. params = list(), the default),
  # and the old `obj == "none"` test then failed with "argument is of length
  # zero". Non-atomic values (e.g. a function -- NOTE(review): confirm whether
  # custom objectives can be stored here) would also make `==` fail.
  obj <- x$params$objective
  if (is.function(obj) || identical(obj, "none")) {
    obj <- "custom"
  } else if (!is.character(obj) || length(obj) != 1L) {
    obj <- "(unknown)"
  }

  num_class <- x$.__enclos_env__$private$num_class
  if (num_class == 1L) {
    cat(sprintf("Objective: %s\n", obj))
  } else {
    cat(sprintf("Objective: %s (%d classes)\n", obj, num_class))
  }

  ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
  cat(sprintf("Fitted to dataset with %d columns\n", ncols))
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # summary() is an alias of print() for boosters: whatever the print method
  # hands back (the booster itself, invisibly) is passed straight through.
  return(print(object))
}

#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the file is used and \code{model_str} is ignored.
#' @param filename path of model file (a single character string)
#' @param model_str a character string containing the model
#'
#' @return An \code{lgb.Booster}, returned invisibly.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # A file path takes precedence over a model string
  if (filename_provided) {
    # Exactly one path is required: file.exists() is vectorized, and a longer
    # vector would make the if() condition below fail with an unclear error
    if (!(is.character(filename) && length(filename) == 1L)) {
      stop("lgb.load: filename should be character")
    }
    filename <- path.expand(filename)
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.character(model_str)) {
      stop("lgb.load: model_str should be character")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

940
941
942
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save a LightGBM model to a file.
#' @param booster Object of class \code{lgb.Booster}
#' @param filename path of the file to write (a single character string)
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return lgb.Booster, returned invisibly
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  # Validate both arguments before writing anything
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }
  is_single_string <- is.character(filename) && length(filename) == 1L
  if (!is_single_string) {
    stop("lgb.save: filename should be a string")
  }

  # The Booster object performs the actual write
  saved <- booster$save_model(
    filename = path.expand(filename)
    , num_iteration = num_iteration
  )
  return(invisible(saved))
}

996
997
998
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dump, NULL or <= 0 means use best iteration
#'
#' @return json format of model (the value returned by \code{booster$dump_model()})
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # Error message must name this function, not lgb.save
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))
}

1041
1042
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for (a single string).
#' @param eval_name Name of the evaluation metric to return results for (a single string).
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Each name is used as a single lookup key below; a longer vector would
  # make the %in% checks return several values and break the if() conditions
  data_name_ok <- is.character(data_name) && length(data_name) == 1L
  eval_name_ok <- is.character(eval_name) && length(eval_name) == 1L
  if (!data_name_ok || !eval_name_ok) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check if evaluation result is existing
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick the recorded metric values, or their errors when is_err = TRUE
  metric_record <- booster$record_evals[[data_name]][[eval_name]]
  if (is_err) {
    result <- metric_record[[.EVAL_ERR_KEY()]]
  } else {
    result <- metric_record[[.EVAL_KEY()]]
  }

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Convert requested iteration numbers into positions within `result` by
  # offsetting them by (record_evals$start_iter - 1)
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}