lgb.Booster.R 31.5 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
# R6 class wrapping a native LightGBM booster handle (private$handle), plus the
# training / validation datasets registered with it for evaluation.
#' @importFrom R6 R6Class
#' @importFrom utils modifyList
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Best iteration; -1L until assigned. Used as the default num_iteration by
    # save_model(), save_model_to_string(), dump_model() and predict().
    best_iter = -1L,
    # Score associated with best_iter; NA until assigned.
    best_score = NA_real_,
    # Parameters the booster was created with (or last reset to).
    params = list(),
    # Evaluation history; no method of this class writes to it.
    record_evals = list(),

    # Free the native booster and drop the reference to it.
    finalize = function() {
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL
      return(invisible(NULL))
    },

    # Create a booster from exactly one source, checked in this order:
    #   train_set - an lgb.Dataset; its parameters are merged into `params`
    #   modelfile - path to a saved model file (tilde-expanded)
    #   model_str - character/raw vector holding a serialized model
    # Errors are deliberately not caught, so a failure (e.g. while merging an
    # init predictor) stops construction instead of leaving a half-built booster.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }
        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        params <- utils::modifyList(params, train_set$get_params())
        params_str <- lgb.params2str(params = params)

        handle <- .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
        )

        # Remember the training data so update() / eval() can refer to it
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        # Continue training from an existing model attached to the dataset
        if (!is.null(private$init_predictor)) {
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
          )
        }

        # One "already predicted this iteration" flag per registered dataset
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        modelfile <- path.expand(modelfile)

        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )

      } else if (!is.null(model_str)) {

        if (!is.raw(model_str) && !is.character(model_str)) {
          stop("lgb.Booster: Can only use a character/raw vector as model_str")
        }

        handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )

      } else {

        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      # The native call returned without an error but gave no usable handle
      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      # LGBM_BoosterGetNumClasses_R writes the class count into num_class
      private$num_class <- 1L
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
      )

      self$params <- params

      return(invisible(NULL))
    },

    # Set the name used for the training set in evaluation output.
    # Returns self invisibly.
    set_train_data_name = function(name) {
      private$name_train_set <- name
      return(invisible(self))
    },

    # Register an lgb.Dataset as a validation set under `name`.
    # The dataset must share the booster's init predictor. Returns self invisibly.
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))
    },

    # Merge `params` over the current parameters and push them to the native
    # booster. Returns self invisibly.
    reset_parameter = function(params) {

      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }

      params_str <- lgb.params2str(params = params)

      self$restore_handle()

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))
    },

    # Run one boosting iteration.
    #   train_set - optional new lgb.Dataset to train on. When NULL and the stored
    #               training set has changed version, that dataset is re-attached.
    #   fobj      - optional custom objective, called as fobj(preds, train_set).
    #               It must return a list with `grad` and `hess`.
    # Returns self invisibly.
    update = function(train_set = NULL, fobj = NULL) {

      # A booster created from a model file/string has no training set to refresh
      if (is.null(train_set) && !is.null(private$train_set)) {
        current_version <- private$train_set$.__enclos_env__$private$version
        if (current_version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # The predictor lives in the Dataset's private environment (as read in
        # initialize() and add_valid()); `train_set$predictor` is always NULL.
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # The objective was switched to "none" by an earlier custom-objective
        # update, so there is no built-in objective left to boost with
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should return a list with attributes (hess, grad)")
        }

        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # Cached predictions are stale after a new iteration
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))
    },

    # Undo the most recent boosting iteration. Returns self invisibly.
    rollback_one_iter = function() {

      self$restore_handle()

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # Cached predictions are stale after removing an iteration
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))
    },

    # Return the current iteration count reported by the native booster.
    current_iter = function() {

      self$restore_handle()

      # Written in place by the native call
      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)
    },

    # Return the upper bound value reported by the native booster.
    upper_bound = function() {

      self$restore_handle()

      # Written in place by the native call
      upper_bound <- 0.0
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)
    },

    # Return the lower bound value reported by the native booster.
    lower_bound = function() {

      self$restore_handle()

      # Written in place by the native call
      lower_bound <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)
    },

    # Evaluate an lgb.Dataset with the configured metrics and optional `feval`.
    # Unknown datasets are registered as validation sets under `name` first.
    # Returns a list of results (see inner_eval()).
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # 1 = training set, i + 1 = i-th validation set, 0 = not registered
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1L
            break
          }
        }
      }

      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )
    },

    # Evaluate the training set.
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluate every registered validation set; the results are concatenated into one list.
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)
    },

    # Write the model to `filename` (tilde-expanded). `num_iteration`
    # defaults to best_iter. Returns self invisibly.
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      filename <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },

    # Serialize the model. `num_iteration` defaults to best_iter. Returns a
    # character string when as_char is TRUE; otherwise returns what the native
    # call returned, unconverted (passed to rawToChar() in the TRUE case).
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      if (as_char) {
        model_str <- rawToChar(model_str)
      }

      return(model_str)
    },

    # Return the native model dump. `num_iteration` defaults to best_iter.
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)
    },

    # Predict on new data through a Predictor built on this booster's handle.
    # num_iteration defaults to best_iter; start_iteration defaults to 0L.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE,
                       params = list()) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
          , reshape = reshape
        )
      )
    },

    # Wrap this booster's handle in a Predictor. The handle is restored first,
    # as predict() does.
    to_predictor = function() {
      self$restore_handle()
      return(Predictor$new(modelfile = private$handle))
    },

    # Serialized model bytes used to rebuild the handle (see restore_handle())
    raw = NULL,

    # Cache the serialized model in self$raw if it is not already cached.
    save_raw = function() {
      if (is.null(self$raw)) {
        self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      }
      return(invisible(NULL))
    },

    # Discard the cached serialized model.
    drop_raw = function() {
      self$raw <- NULL
      return(invisible(NULL))
    },

    # Return the result of lgb.is.null.handle() on the native handle.
    check_null_handle = function() {
      return(lgb.is.null.handle(private$handle))
    },

    # If the handle is null, rebuild it from self$raw. When no raw copy exists,
    # LGBM_NullBoosterHandleError_R is called.
    restore_handle = function() {
      if (self$check_null_handle()) {
        if (is.null(self$raw)) {
          .Call(LGBM_NullBoosterHandleError_R)
        }
        private$handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
      }
      return(invisible(NULL))
    },

    # Return the native handle as-is (no restore attempted).
    get_handle = function() {
      return(private$handle)
    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    # Prediction buffers keyed by dataset name
    predict_buffer = list(),
    # Per-dataset flag: predictions for the current iteration already fetched
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    # TRUE once update() has switched the objective to "none" for a custom fobj
    set_objective_to_none = FALSE,
    train_set_version = 0L,

    # Return current predictions for dataset `idx` (1 = training set,
    # i + 1 = i-th validation set), fetching them from the native booster at
    # most once per iteration.
    inner_predict = function(idx) {

      # Validate before indexing name_valid_sets, otherwise an out-of-range idx
      # fails with "subscript out of bounds" instead of this message
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the buffer the first time this dataset is predicted
      if (is.null(private$predict_buffer[[data_name]])) {
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)
      }

      # The native call fills the buffer in place
      if (!private$is_predicted_cur_iter[[idx]]) {
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Fetch (once) and cache the metric names and their higher-is-better flags.
    get_eval_info = function() {

      if (is.null(private$eval_names)) {
        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {

          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }
      }

      return(private$eval_names)
    },

    # Evaluate dataset `data_idx` with built-in metrics and optional `feval`.
    # Each result is a list(data_name, name, value, higher_better).
    # `feval(preds, dataset)` must return a list with name, value, higher_better.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Filled in place by the native call
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))
        }

      }

      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)")
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)
    }

  )
)

730
731
732
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
#'             a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param start_iteration integer or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration integer or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the form of the original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf indices instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values.
#' @param ... ignored (a warning lists anything passed here)
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
#'
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'    )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                params = list(),
                                ...) {

  # Only lgb.Booster objects provide the $predict() method used below
  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Nothing in '...' is forwarded; warn so callers move settings into 'params'
  ignored_args <- list(...)
  if (length(ignored_args) > 0L) {
    ignored_names <- paste(names(ignored_args), collapse = ", ")
    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , ignored_names
      , ". These are ignored. Use argument 'params' instead."
    ))
  }

  preds <- object$predict(
    data = data
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , reshape = reshape
    , params = params
  )
  return(preds)
}

839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#'              The objective is shown as "custom" for custom objectives and as
#'              "unknown" when the booster's \code{params} do not record one
#'              (e.g. a model created with \code{lgb.load}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  handle <- x$.__enclos_env__$private$handle
  handle_is_null <- lgb.is.null.handle(handle)

  if (handle_is_null) {

    cat("LightGBM Model\n")
    cat("(Booster handle is invalid)\n")

  } else {

    ntrees <- x$current_iter()
    if (ntrees == 1L) {
      cat("LightGBM Model (1 tree)\n")
    } else {
      cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
    }

    # params is list() for boosters built from a model file/string, so the
    # objective may be missing; a function or "none" means a custom objective
    obj <- x$params$objective
    if (is.null(obj)) {
      obj <- "unknown"
    } else if (is.function(obj) || identical(obj, "none")) {
      obj <- "custom"
    }

    num_class <- x$.__enclos_env__$private$num_class
    if (num_class == 1L) {
      cat(sprintf("Objective: %s\n", obj))
    } else {
      cat(sprintf("Objective: %s (%d classes)\n"
          , obj
          , num_class))
    }

    ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
    cat(sprintf("Fitted to dataset with %d columns\n", ncols))

  }
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (identical output to \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # Delegate all output to the print method, then hand the object back quietly
  print(object)
  return(invisible(object))
}

899
900
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
#' @param filename path of model file
#' @param model_str a string containing the model (as a `character` or `raw` vector)
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  # A file path takes precedence over a model string when both are supplied
  if (!is.null(filename)) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    model_path <- path.expand(filename)
    if (!file.exists(model_path)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", model_path))
    }
    return(invisible(Booster$new(modelfile = model_path)))
  }

  if (is.null(model_str)) {
    stop("lgb.load: either filename or model_str must be given")
  }

  if (!(is.raw(model_str) || is.character(model_str))) {
    stop("lgb.load: model_str should be a character/raw vector")
  }

  return(invisible(Booster$new(model_str = model_str)))
}

#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model to a text file
#' @param booster Object of class \code{lgb.Booster}
#' @param filename path of the file to write, as a single string
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return lgb.Booster, returned invisibly
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # Require exactly one non-missing path; NA_character_ is a character vector
  # of length 1 but would otherwise be forwarded as a literal file name
  if (!(is.character(filename) && length(filename) == 1L && !is.na(filename))) {
    stop("lgb.save: filename should be a string")
  }
  filename <- path.expand(filename)

  # Delegate the actual writing to the booster
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )
}

#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dump, NULL or <= 0 means use best iteration
#'
#' @return json format of model, as produced by the booster's \code{dump_model()} method
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    # Error prefix names this function (previously copy-pasted as "lgb.save:")
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))
}

#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for (a single string).
#' @param eval_name Name of the evaluation metric to return results for (a single string).
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#'              Iteration numbers are absolute: the first recorded value corresponds to
#'              iteration \code{booster$record_evals$start_iter}.
#' @param is_err If TRUE, return the recorded evaluation errors instead of the metric values.
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }
  # %in% is vectorised, so a multi-element name would hand the if() checks
  # below a condition of length > 1 (an error in current R)
  if (length(data_name) != 1L || length(eval_name) != 1L) {
    stop("lgb.get.eval.result: data_name and eval_name should each be a single string")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check if evaluation result is existing
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick the recorded metric values, or their errors when is_err is TRUE
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Values are recorded from iteration `start_iter` onwards, so shift the
  # requested (absolute) iteration numbers to positions within `result`
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}