#' @importFrom R6 R6Class
#' @importFrom utils modifyList
Booster <- R6::R6Class(
4
  classname = "lgb.Booster",
5
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
6
  public = list(
7

8
    best_iter = -1L,
9
    best_score = NA_real_,
10
    params = list(),
Guolin Ke's avatar
Guolin Ke committed
11
    record_evals = list(),
12

13
14
    # Finalize will free up the handles
    finalize = function() {
      # Hand the stored handle to the native free routine, then clear the
      # R-side reference so the released handle is not reused.
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL
      return(invisible(NULL))
    },
22

23
24
    # Initialize will create a starter booster
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      # The handle stays NULL unless one of the three creation paths succeeds
      handle <- NULL

      # Errors raised inside try() are reported but do not propagate; the
      # NULL-handle check after this block turns any failure into a single
      # "cannot create Booster handle" error.
      try({

        # Path 1: create a new booster from a training dataset
        if (!is.null(train_set)) {
          # Only lgb.Dataset objects are accepted as training data
          if (!lgb.is.Dataset(train_set)) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }
          train_set_handle <- train_set$.__enclos_env__$private$get_handle()
          # Parameters stored on the dataset override the ones passed in
          params <- utils::modifyList(params, train_set$get_params())
          params_str <- lgb.params2str(params = params)
          # Store booster handle
          handle <- .Call(
            LGBM_BoosterCreate_R
            , train_set_handle
            , params_str
          )

          # Create private booster information
          private$train_set <- train_set
          private$train_set_version <- train_set$.__enclos_env__$private$version
          private$num_dataset <- 1L
          private$init_predictor <- train_set$.__enclos_env__$private$predictor

          # If the dataset carries a predictor, merge its model into the new booster
          if (!is.null(private$init_predictor)) {

            # Merge booster
            .Call(
              LGBM_BoosterMerge_R
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )

          }

          # One "already predicted" flag per registered dataset; this one is
          # for the training data
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        # Path 2: load a booster from a model file
        } else if (!is.null(modelfile)) {

          # Do we have a model file as character?
          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          modelfile <- path.expand(modelfile)

          # Create booster from model
          handle <- .Call(
            LGBM_BoosterCreateFromModelfile_R
            , modelfile
          )

        # Path 3: load a booster from an in-memory model string
        } else if (!is.null(model_str)) {

          # Do we have a model_str as character/raw?
          if (!is.raw(model_str) && !is.character(model_str)) {
            stop("lgb.Booster: Can only use a character/raw vector as model_str")
          }

          # Create booster from model
          handle <- .Call(
            LGBM_BoosterLoadModelFromString_R
            , model_str
          )

        } else {

          # Nothing to build a booster from
          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      # Check whether the handle was created properly if it was not stopped earlier by a stop call
      if (isTRUE(lgb.is.null.handle(x = handle))) {

        stop("lgb.Booster: cannot create Booster handle")

      } else {

        # Tag the handle with its class and keep it on the object
        class(handle) <- "lgb.Booster.handle"
        private$handle <- handle
        private$num_class <- 1L
        # NOTE(review): the native call receives private$num_class and its
        # return value is discarded, so it presumably writes the class count
        # into that argument in place - confirm against the C wrapper.
        .Call(
          LGBM_BoosterGetNumClasses_R
          , private$handle
          , private$num_class
        )

      }

      self$params <- params

      return(invisible(NULL))

    },
136

137
    # Set training data name
Guolin Ke's avatar
Guolin Ke committed
138
    set_train_data_name = function(name) {

      # Stored privately; eval_train() and inner_predict() read this name.
      # Returns the booster invisibly.
      private$name_train_set <- name
      return(invisible(self))

    },
145

146
    # Add validation data
Guolin Ke's avatar
Guolin Ke committed
147
    add_valid = function(data, name) {

      # Register `data` as a validation set under `name`.
      # Returns the booster invisibly.

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # The validation data must carry the same predictor the booster was created with
      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Add validation data to booster
      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      # Keep the R-side bookkeeping in step: dataset, its name, the dataset
      # count and one more "already predicted" flag
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },
179

180
    reset_parameter = function(params) {

      # Merge `params` over the stored parameters, push the result to the
      # native booster and remember it. Returns the booster invisibly.

      # New values override stored ones; stored values not mentioned are kept
      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }

      params_str <- lgb.params2str(params = params)

      self$restore_handle()

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))

    },
200

201
    # Perform boosting update iteration
Guolin Ke's avatar
Guolin Ke committed
202
    update = function(train_set = NULL, fobj = NULL) {

      # Run one boosting iteration, optionally switching the training data
      # (`train_set`) and/or using a custom objective function (`fobj`).
      # Returns the booster invisibly.

      # No dataset given: re-register the stored one if its version changed
      # since it was last handed to the native booster
      if (is.null(train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Read the predictor from the dataset's private environment, the same
        # place initialize() and add_valid() read it from
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      # Check if objective is empty
      if (is.null(fobj)) {
        # Once the objective has been switched to "none", a custom objective
        # function is required
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost iteration from known objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        # First custom-objective update: switch the objective to "none"
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }
        # Call the custom objective with the current training predictions
        # and the training dataset
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        # Check for gradient and hessian as list
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should
            return a list with attributes (hess, grad)")
        }

        # Boost one iteration from the supplied gradient/hessian
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # Reset every "already predicted" flag so the next inner_predict()
      # fetches predictions again
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },
279

280
    # Return one iteration behind
Guolin Ke's avatar
Guolin Ke committed
281
    rollback_one_iter = function() {

      # Ask the native booster to roll back one iteration.
      # Returns the booster invisibly.

      self$restore_handle()

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # Reset every "already predicted" flag so the next inner_predict()
      # fetches predictions again
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },
298

299
    # Get current iteration
Guolin Ke's avatar
Guolin Ke committed
300
    current_iter = function() {

      self$restore_handle()

      # NOTE(review): the return value of .Call is discarded and `cur_iter`
      # is returned afterwards, so the native routine presumably writes the
      # iteration count into `cur_iter` in place - confirm against the C wrapper.
      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },
313

314
    # Get upper bound
315
    upper_bound = function() {

      self$restore_handle()

      # NOTE(review): `upper_bound` is passed to the native routine and then
      # returned, so it is presumably filled in place - confirm against the
      # C wrapper.
      upper_bound <- 0.0
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)

    },

    # Get lower bound
330
    lower_bound = function() {

      self$restore_handle()

      # NOTE(review): `lower_bound` is passed to the native routine and then
      # returned, so it is presumably filled in place - confirm against the
      # C wrapper.
      lower_bound <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)

    },

344
    # Evaluate data on metrics
Guolin Ke's avatar
Guolin Ke committed
345
    eval = function(data, name, feval = NULL) {

      # Evaluate `data`, registering it as a new validation set under `name`
      # if the booster does not already know it. `feval` is an optional
      # custom evaluation function. Returns the list built by inner_eval().

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate `data` among the known datasets: index 1 is the training set,
      # index i + 1 is the i-th validation set, 0 means "not found"
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {

        # Check for validation data
        if (length(private$valid_sets) > 0L) {

          for (i in seq_along(private$valid_sets)) {

            # Stop at the first validation set identical to `data`
            if (identical(data, private$valid_sets[[i]])) {

              data_idx <- i + 1L
              break

            }

          }

        }

      }

      # Unknown dataset: add it as validation data, which makes it the last
      # registered dataset
      if (data_idx == 0L) {

        self$add_valid(data, name)
        data_idx <- private$num_dataset

      }

      # Evaluate data
      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },
396

397
    # Evaluation training data
Guolin Ke's avatar
Guolin Ke committed
398
    eval_train = function(feval = NULL) {
      # The training data is always dataset index 1
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },
401

402
    # Evaluation validation data
Guolin Ke's avatar
Guolin Ke committed
403
    eval_valid = function(feval = NULL) {

      # Evaluate every registered validation set and return all results in
      # one flat list (empty when there are no validation sets)
      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      # Validation set i is dataset index i + 1 (index 1 is the training data)
      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },
421

422
    # Save model
423
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      # Write the model to `filename` through the native save routine.
      # Returns the booster invisibly.

      self$restore_handle()

      # Default to the recorded best iteration
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      filename <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },
443

444
445
446
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

      # Serialize the model in memory. With as_char = TRUE the native result
      # is converted with rawToChar(); otherwise it is returned as received.

      self$restore_handle()

      # Default to the recorded best iteration
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      if (as_char) {
        model_str <- rawToChar(model_str)
      }

      return(model_str)

    },
466

467
    # Dump model in memory
468
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      # Return whatever the native dump routine produces for this model

      self$restore_handle()

      # Default to the recorded best iteration
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)

    },
486

487
    # Predict on new data
Guolin Ke's avatar
Guolin Ke committed
488
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE,
                       params = list()) {

      # Build a Predictor on this booster's handle and forward every
      # prediction argument to it unchanged

      self$restore_handle()

      # Default to the recorded best iteration
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # Predict on new data
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
          , reshape = reshape
        )
      )

    },
527

528
529
    # Transform into predictor
    to_predictor = function() {
      # Restore the handle first, as predict() does before building its
      # Predictor from the same handle
      self$restore_handle()
      return(Predictor$new(modelfile = private$handle))
    },
532

533
534
    # Used for serialization
    raw = NULL,
535

536
537
538
539
540
541
    # Store serialized raw bytes in model object
    save_raw = function() {
      # Cache the serialized model (not converted to character) in `raw`;
      # an existing cache is left untouched
      if (is.null(self$raw)) {
        self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      }
      return(invisible(NULL))

    },
544

545
546
    drop_raw = function() {
      # Discard the cached serialized model
      self$raw <- NULL
      return(invisible(NULL))
    },
549

550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
    check_null_handle = function() {
      # Report what lgb.is.null.handle() says about the stored handle
      handle_missing <- lgb.is.null.handle(private$handle)
      return(handle_missing)
    },

    restore_handle = function() {
      # Rebuild the native handle from the cached raw model when the stored
      # handle is null; does nothing when the handle is still usable
      if (self$check_null_handle()) {
        # Nothing to restore from: let the native side raise its
        # null-handle error
        if (is.null(self$raw)) {
          .Call(LGBM_NullBoosterHandleError_R)
        }
        restored <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
        # Tag the restored handle the same way initialize() tags a freshly
        # created one
        class(restored) <- "lgb.Booster.handle"
        private$handle <- restored
      }
      return(invisible(NULL))
    },

    get_handle = function() {
      # Expose the stored handle as-is (no restore is attempted here)
      return(private$handle)
    }
567

Guolin Ke's avatar
Guolin Ke committed
568
569
  ),
  private = list(
570
571
572
573
574
575
576
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
577
578
    num_class = 1L,
    num_dataset = 0L,
579
580
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
581
    higher_better_inner_eval = NULL,
582
    set_objective_to_none = FALSE,
583
    train_set_version = 0L,
584
585
    # Predict data
    inner_predict = function(idx) {

      # Return the predictions for dataset `idx` (1 = training data,
      # i + 1 = i-th validation set), fetching them from the native booster
      # only when they have not been fetched since the last update

      # Validate the index before it is used to look up a dataset name, so an
      # out-of-range index produces this message rather than a subscript error
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Buffers are keyed by dataset name
      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the prediction buffer for this dataset on first use
      if (is.null(private$predict_buffer[[data_name]])) {

        # NOTE(review): `npred` is passed to the native routine and then used
        # as the buffer length, so it is presumably filled in place - confirm
        # against the C wrapper. The native index is zero-based (idx - 1L).
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Check if current iteration was already predicted
      if (!private$is_predicted_cur_iter[[idx]]) {

        # NOTE(review): the buffer is passed to the native routine and later
        # returned, so it is presumably written in place - confirm
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },
629

630
    # Get evaluation information
Guolin Ke's avatar
Guolin Ke committed
631
    get_eval_info = function() {

      # Fetch the evaluation metric names from the native booster once and
      # cache them; later calls return the cached names

      if (is.null(private$eval_names)) {
        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {

          # Parse and store privately names
          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },
656

Guolin Ke's avatar
Guolin Ke committed
657
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Evaluate dataset `data_idx` and return a list with one entry per
      # metric; each entry holds data_name, name, value and higher_better.
      # The result of `feval`, when given, is appended as a further entry.

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      # Make sure the metric names are cached
      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # NOTE(review): `tmp_vals` is passed to the native routine and read
        # afterwards, so it is presumably filled in place - confirm against
        # the C wrapper. The native index is zero-based (data_idx - 1L).
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {

          # Store evaluation and append to return
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Custom evaluation function, if supplied
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset handed to feval: training data for index 1,
        # otherwise the matching validation set
        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        # Perform function evaluation
        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) ||  is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a
            list with attribute (name, value, higher_better)")
        }

        # Append names and evaluation
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }
726

Guolin Ke's avatar
Guolin Ke committed
727
728
729
  )
)

730
731
732
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
#'             a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param start_iteration int or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration int or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values.
#' @param ... ignored
#' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}.
#'         For multiclass classification, either a \code{num_class * nrows(data)} vector or
#'         a \code{(nrows(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
#'
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'    )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                params = list(),
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Anything passed through `...` is not forwarded; warn so that misspelled
  # or misplaced arguments do not go unnoticed
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , paste(names(additional_params), collapse = ", ")
      , ". These are ignored. Use argument 'params' instead."
    ))
  }

  # Delegate to the booster's own predict() method
  return(
    object$predict(
      data = data
      , start_iteration = start_iteration
      , num_iteration = num_iteration
      , rawscore = rawscore
      , predleaf =  predleaf
      , predcontrib =  predcontrib
      , header = header
      , reshape = reshape
      , params = params
    )
  )
}

839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#'              Prints the number of trees, the objective (shown as \code{(unknown)} when the
#'              booster's \code{params} do not record one) and the number of feature columns.
#'              When the booster handle is null, only a short notice is printed.
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  handle <- x$.__enclos_env__$private$handle
  handle_is_null <- lgb.is.null.handle(handle)

  if (!handle_is_null) {
    ntrees <- x$current_iter()
    if (ntrees == 1L) {
      cat("LightGBM Model (1 tree)\n")
    } else {
      cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
    }
  } else {
    cat("LightGBM Model\n")
  }

  if (!handle_is_null) {
    obj <- x$params$objective
    # `params` is an empty list for boosters created from a model file or
    # model string, so the objective may be absent; comparing NULL with
    # "none" inside if() would raise "argument is of length zero"
    if (is.null(obj)) {
      obj <- "(unknown)"
    } else if (isTRUE(obj == "none")) {
      # update() switches the objective to "none" when a custom objective
      # function is used
      obj <- "custom"
    }
    if (x$.__enclos_env__$private$num_class == 1L) {
      cat(sprintf("Objective: %s\n", obj))
    } else {
      cat(sprintf("Objective: %s (%d classes)\n"
          , obj
          , x$.__enclos_env__$private$num_class))
    }
  } else {
    cat("(Booster handle is invalid)\n")
  }

  if (!handle_is_null) {
    ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
    cat(sprintf("Fitted to dataset with %d columns\n", ncols))
  }
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # summary() is an alias for print(): dispatch on the object's class and
  # hand back whatever the print method returns
  printed <- print(object)
  return(invisible(printed))
}

898
899
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load LightGBM takes in either a file path or model string.
#'              If both are provided, Load will default to loading from file
#' @param filename path of model file
#' @param model_str a str containing the model (as a `character` or `raw` vector)
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # A file path takes precedence over a model string when both are given
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    filename <- path.expand(filename)
    # Fail early with a clear message instead of letting the Booster
    # constructor deal with a missing file
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.raw(model_str) && !is.character(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  # Neither source was supplied
  stop("lgb.load: either filename or model_str must be given")
}

#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # The path must be exactly one string
  if (!(is.character(filename) && length(filename) == 1L)) {
    stop("lgb.save: filename should be a string")
  }
  # Expand a leading "~" before handing the path to the booster
  filename <- path.expand(filename)

  # Store booster; the result of save_model() is returned invisibly
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )

}

#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump, NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # The message names this function (it previously said "lgb.save")
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))

}

#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check if evaluation result is existing
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick the stored series: the error values when requested, otherwise the
  # metric values themselves
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Shift the requested iterations by the recorded start iteration so they
  # become positions within the stored series
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}