lgb.Booster.R 35.8 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
#' @importFrom R6 R6Class
2
#' @importFrom utils modifyList
James Lamb's avatar
James Lamb committed
3
Booster <- R6::R6Class(
  # Internal R6 wrapper around a native LightGBM booster handle.
  #
  # The object owns `private$handle` (created through the LGBM_*_R native
  # routines), keeps references to the training / validation Datasets that were
  # attached to it, and caches per-dataset prediction buffers.
  # `cloneable = FALSE`: no R6 `$clone()` method is generated for this class.
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Public state. Not modified by any method in this class except `params`
    # (set by initialize() / reset_parameter()); `best_iter` is read as the
    # default `num_iteration` by the save / dump / predict methods.
    best_iter = -1L,
    best_score = NA_real_,
    params = list(),
    record_evals = list(),

    # Finalize will free up the handles
    finalize = function() {
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL
      return(invisible(NULL))
    },

    # Initialize will create a starter booster.
    #
    # Exactly one source is used, checked in this order:
    #   train_set - an lgb.Dataset; a new booster is created for training
    #   modelfile - character path to a saved model
    #   model_str - character or raw vector holding a serialized model
    # `params` is only merged / converted to a string on the `train_set` path,
    # but is always stored in `self$params`.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }
        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        # Dataset parameters take precedence over the ones passed in
        params <- utils::modifyList(params, train_set$get_params())
        params_str <- lgb.params2str(params = params)
        # Store booster handle
        handle <- .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
        )

        # Create private booster information
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        if (!is.null(private$init_predictor)) {

          # Merge booster
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
          )

        }

        # One "already predicted" flag per attached dataset; slot 1 is the training data
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        # Do we have a model file as character?
        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        modelfile <- path.expand(modelfile)

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )

      } else if (!is.null(model_str)) {

        # Do we have a model_str as character/raw?
        if (!is.raw(model_str) && !is.character(model_str)) {
          stop("lgb.Booster: Can only use a character/raw vector as model_str")
        }

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )

      } else {

        # Booster non existent
        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      # The return value of .Call is discarded; `private$num_class` is handed to
      # the native routine (presumably updated in place - confirm in the C++ glue)
      private$num_class <- 1L
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
      )

      self$params <- params

      return(invisible(NULL))

    },

    # Set the name under which the training data is reported in evaluations
    # and used as key of its prediction buffer
    set_train_data_name = function(name) {

      private$name_train_set <- name
      return(invisible(self))

    },

    # Add validation data.
    #
    # data - an lgb.Dataset built with the same predictor as the training data
    # name - character name used for this dataset in evaluation results
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Add validation data to booster
      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      # Keep the bookkeeping vectors aligned: dataset index i + 1 <-> valid_sets[[i]]
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Merge `params` into the stored parameters and push the result to the
    # native booster
    reset_parameter = function(params) {

      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }

      params_str <- lgb.params2str(params = params)

      self$restore_handle()

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))

    },

    # Perform boosting update iteration.
    #
    # train_set - optional lgb.Dataset replacing the current training data
    # fobj      - optional custom objective `function(preds, dtrain)` returning
    #             a list with elements `grad` and `hess`
    update = function(train_set = NULL, fobj = NULL) {

      if (is.null(train_set)) {
        # A Booster created from a model file / model string never stores a
        # Dataset; fail with a clear message instead of a cryptic
        # "argument is of length zero" from the version comparison below
        if (is.null(private$train_set)) {
          stop("lgb.Booster.update: no training data is available; pass an lgb.Dataset as train_set")
        }
        # Re-attach the stored Dataset when its version changed since it was last attached
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Read the predictor from the Dataset's private environment, the same
        # way initialize() and add_valid() do
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      # Check if objective is empty
      if (is.null(fobj)) {
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost iteration from known objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        # First use of a custom objective: switch the booster's objective to "none" once
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }
        # Perform objective calculation
        preds <- private$inner_predict(1L)
        gpair <- fobj(preds, private$train_set)

        # Check for gradient and hessian as list
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should return a list with attributes (hess, grad)")
        }

        # Check grad and hess have the right shape
        n_grad <- length(gpair$grad)
        n_hess <- length(gpair$hess)
        n_preds <- length(preds)
        if (n_grad != n_preds) {
          stop(sprintf("Expected custom objective function to return grad with length %d, got %d.", n_preds, n_grad))
        }
        if (n_hess != n_preds) {
          stop(sprintf("Expected custom objective function to return hess with length %d, got %d.", n_preds, n_hess))
        }

        # Return custom boosting gradient/hessian
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , n_preds
        )

      }

      # The model changed: invalidate the cached predictions of every dataset
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Return one iteration behind
    rollback_one_iter = function() {

      self$restore_handle()

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # The model changed: invalidate the cached predictions of every dataset
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Get current iteration.
    # The .Call result is discarded; the value returned is what `cur_iter`
    # holds afterwards (presumably written in place by the native routine)
    current_iter = function() {

      self$restore_handle()

      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Get upper bound (same in-place output convention as current_iter())
    upper_bound = function() {

      self$restore_handle()

      upper_bound <- 0.0
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)

    },

    # Get lower bound (same in-place output convention as current_iter())
    lower_bound = function() {

      self$restore_handle()

      lower_bound <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)

    },

    # Evaluate data on metrics.
    #
    # data  - an lgb.Dataset; if it is neither the training data nor an already
    #         attached validation set it is attached via add_valid()
    # name  - name reported as `data_name` in the results
    # feval - optional custom evaluation function
    # Returns the list produced by private$inner_eval()
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Look for the dataset among the ones already attached (0L = not found)
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {

        # Check for validation data
        if (length(private$valid_sets) > 0L) {

          for (i in seq_along(private$valid_sets)) {

            if (identical(data, private$valid_sets[[i]])) {

              # Found it: validation set i lives at dataset index i + 1
              data_idx <- i + 1L
              break

            }

          }

        }

      }

      # Unknown dataset: attach it, it becomes the last dataset
      if (data_idx == 0L) {

        self$add_valid(data, name)
        data_idx <- private$num_dataset

      }

      # Evaluate data
      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },

    # Evaluation training data
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluation validation data: concatenated results of every attached
    # validation set, in the order they were added (empty list if none)
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },

    # Save model to `filename` (tilde-expanded).
    # `num_iteration = NULL` falls back to `self$best_iter`
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      filename <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },

    # Serialize the model in memory.
    # `num_iteration = NULL` falls back to `self$best_iter`.
    # The native result is converted with rawToChar() when `as_char` is TRUE
    # and returned unchanged otherwise
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
          LGBM_BoosterSaveModelToString_R
          , private$handle
          , as.integer(num_iteration)
          , as.integer(feature_importance_type)
      )

      if (as_char) {
        model_str <- rawToChar(model_str)
      }

      return(model_str)

    },

    # Dump model in memory.
    # `num_iteration = NULL` falls back to `self$best_iter`
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)

    },

    # Predict on new data.
    #
    # Builds a Predictor from the current handle and forwards all arguments to
    # its predict() method. `num_iteration = NULL` falls back to
    # `self$best_iter`, `start_iteration = NULL` to 0L. When `params` is
    # non-empty, `rawscore` / `predleaf` / `predcontrib` are replaced by the
    # values lgb.check.wrapper_param() resolves for them.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       params = list()) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # possibly override keyword arguments with parameters
      #
      # NOTE: this length() check minimizes the latency introduced by these checks,
      #       for the common case where params is empty
      #
      # NOTE: doing this here instead of in Predictor$predict() to keep
      #       Predictor$predict() as fast as possible
      if (length(params) > 0L) {
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_raw_score"
          , params = params
          , alternative_kwarg_value = rawscore
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_leaf_index"
          , params = params
          , alternative_kwarg_value = predleaf
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_contrib"
          , params = params
          , alternative_kwarg_value = predcontrib
        )
        rawscore <- params[["predict_raw_score"]]
        predleaf <- params[["predict_leaf_index"]]
        predcontrib <- params[["predict_contrib"]]
      }

      # Predict on new data
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
        )
      )

    },

    # Transform into predictor
    to_predictor = function() {
      return(Predictor$new(modelfile = private$handle))
    },

    # Used for serialization: serialized model as produced by
    # save_model_to_string(as_char = FALSE), or NULL when not stored
    raw = NULL,

    # Store serialized raw bytes in model object (no-op if already stored)
    save_raw = function() {
      if (is.null(self$raw)) {
        self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      }
      return(invisible(NULL))

    },

    # Discard the stored serialized model
    drop_raw = function() {
      self$raw <- NULL
      return(invisible(NULL))
    },

    # TRUE when lgb.is.null.handle() reports the native handle as null
    check_null_handle = function() {
      return(lgb.is.null.handle(private$handle))
    },

    # Re-create the native handle from `self$raw` when the current handle is
    # null. Without stored raw bytes LGBM_NullBoosterHandleError_R is invoked
    # (presumably raising an R error - confirm in the C++ glue).
    restore_handle = function() {
      if (self$check_null_handle()) {
        if (is.null(self$raw)) {
          .Call(LGBM_NullBoosterHandleError_R)
        }
        handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
        # Tag the handle the same way initialize() does so that a restored
        # handle is indistinguishable from a freshly created one
        class(handle) <- "lgb.Booster.handle"
        private$handle <- handle
      }
      return(invisible(NULL))
    },

    # Return the native booster handle
    get_handle = function() {
      return(private$handle)
    }

  ),
  private = list(
    # Native booster handle (tagged with class "lgb.Booster.handle")
    handle = NULL,
    # Training Dataset and the name it is reported under
    train_set = NULL,
    name_train_set = "training",
    # Validation Datasets and their names, in the order they were added
    valid_sets = list(),
    name_valid_sets = list(),
    # Prediction buffers keyed by dataset name, and one flag per dataset index
    # telling whether the buffer is up to date for the current iteration
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1L,
    # Number of attached datasets (training data + validation sets)
    num_dataset = 0L,
    init_predictor = NULL,
    # Cached metric names and, for each, whether higher is better
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    # TRUE once update() switched the objective to "none" for a custom objective
    set_objective_to_none = FALSE,
    # Version of `train_set` at the time it was last attached to the booster
    train_set_version = 0L,

    # Predict data.
    #
    # idx - 1-based dataset index: 1 is the training data, i + 1 is
    #       valid_sets[[i]]
    # Returns the prediction buffer for that dataset, refreshed through the
    # native routine when it is stale for the current iteration
    inner_predict = function(idx) {

      # Check for unknown dataset (over the maximum provided range) before
      # using `idx` to index the name list, so that the caller gets this
      # message rather than a "subscript out of bounds" error
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Store data name
      data_name <- private$name_train_set

      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Check for prediction buffer
      if (is.null(private$predict_buffer[[data_name]])) {

        # Ask the native side how many values it produces and allocate the
        # buffer (the native dataset index is 0-based, hence idx - 1L)
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Check if current iteration was already predicted
      if (!private$is_predicted_cur_iter[[idx]]) {

        # Use buffer
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Get evaluation information.
    # Fetches the metric names from the native booster on first use, caches
    # them (and their "higher is better" flags) and returns the cached names
    get_eval_info = function() {

      if (is.null(private$eval_names)) {
        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {

          # Parse and store privately names
          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },

    # Evaluate one attached dataset.
    #
    # data_name - name stored in each result's `data_name`
    # data_idx  - 1-based dataset index (1 = training data)
    # feval     - optional custom evaluation `function(preds, dataset)` that
    #             must return a list with `name`, `value` and `higher_better`
    # Returns a list with one entry (data_name, name, value, higher_better)
    # per built-in metric, followed by the custom evaluation result if any
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Create evaluation values
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {

          # Store evaluation and append to return
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Check if there is a custom evaluation function
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set

        # Any index above 1 refers to a validation set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        # Perform function evaluation
        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) ||  is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)"
          )
        }

        # Append names and evaluation
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }

  )
)

751
752
753
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
Guolin Ke's avatar
Guolin Ke committed
754
#' @param object Object of class \code{lgb.Booster}
755
756
#' @param newdata a \code{matrix} object, a \code{dgCMatrix} object or
#'                a character representing a path to a text file (CSV, TSV, or LibSVM)
757
758
759
760
761
762
763
764
765
766
767
768
769
#' @param type Type of prediction to output. Allowed types are:\itemize{
#'             \item \code{"response"}: will output the predicted score according to the objective function being
#'                   optimized (depending on the link function that the objective uses), after applying any necessary
#'                   transformations - for example, for \code{objective="binary"}, it will output class probabilities.
#'             \item \code{"class"}: for classification objectives, will output the class with the highest predicted
#'                   probability. For other objectives, will output the same as "response".
#'             \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations'
#'                   results) from which the "response" number is produced for a given objective function - for example,
#'                   for \code{objective="binary"}, this corresponds to log-odds. For many objectives such as
#'                   "regression", since no transformation is applied, the output will be the same as for "response".
#'             \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observation falls
#'                   in each tree in the model, outputted as integers, with one column per tree.
#'             \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an
770
#'                   intercept (each feature will produce one column).
771
772
773
774
#'             }
#'
#'             Note that, if using custom objectives, types "class" and "response" will not be available and will
#'             default towards using "raw" instead.
775
776
777
778
779
780
781
782
#' @param start_iteration int or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration int or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
Guolin Ke's avatar
Guolin Ke committed
783
#' @param header only used for prediction for text file. \code{TRUE} if text file has header
784
785
786
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
787
788
#'               valid values. Where these conflict with the values of keyword arguments to this function,
#'               the values in \code{params} take precedence.
789
#' @param ... ignored
790
#' @return For prediction types that are meant to always return one output per observation (e.g. when predicting
791
792
#'         \code{type="response"} or \code{type="raw"} on a binary classification or regression objective), will
#'         return a vector with one element per row in \code{newdata}.
793
#'
794
#'         For prediction types that are meant to return more than one output per observation (e.g. when predicting
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
#'         \code{type="response"} or \code{type="raw"} on a multi-class objective, or when predicting
#'         \code{type="leaf"}, regardless of objective), will return a matrix with one row per observation in
#'         \code{newdata} and one column per output.
#'
#'         For \code{type="leaf"} predictions, will return a matrix with one row per observation in \code{newdata}
#'         and one column per tree. Note that for multiclass objectives, LightGBM trains one tree per class at each
#'         boosting iteration. That means that, for example, for a multiclass model with 3 classes, the leaf
#'         predictions for the first class can be found in columns 1, 4, 7, 10, etc.
#'
#'         For \code{type="contrib"}, will return a matrix of SHAP values with one row per observation in
#'         \code{newdata} and columns corresponding to features. For regression, ranking, cross-entropy, and binary
#'         classification objectives, this matrix contains one column per feature plus a final column containing the
#'         Shapley base value. For multiclass objectives, this matrix will represent \code{num_classes} such matrices,
#'         in the order "feature contributions for first class, feature contributions for second class, feature
#'         contributions for third class, etc.".
810
#'
Guolin Ke's avatar
Guolin Ke committed
811
#' @examples
812
#' \donttest{
813
814
815
816
817
818
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
819
820
821
822
823
824
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
825
#' valids <- list(test = dtest)
826
827
828
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
829
#'   , nrounds = 5L
830
831
#'   , valids = valids
#' )
832
#' preds <- predict(model, test$data)
833
834
#'
#' # pass other prediction parameters
835
#' preds <- predict(
836
837
838
839
840
841
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'    )
#' )
842
#' }
843
#' @importFrom utils modifyList
Guolin Ke's avatar
Guolin Ke committed
844
#' @export
James Lamb's avatar
James Lamb committed
845
predict.lgb.Booster <- function(object,
                                newdata,
                                type = "response",
                                start_iteration = NULL,
                                num_iteration = NULL,
                                header = FALSE,
                                params = list(),
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Anything arriving through '...' is either a removed argument (hard error
  # pointing at its replacement) or unrecognized (warned about and ignored).
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    additional_params_names <- names(additional_params)
    if ("reshape" %in% additional_params_names) {
      stop("'reshape' argument is no longer supported.")
    }

    # Removed logical flags and the value of 'type' that replaces each of them
    old_args_for_type <- list(
      "rawscore" = "raw"
      , "predleaf" = "leaf"
      , "predcontrib" = "contrib"
    )
    for (arg in names(old_args_for_type)) {
      if (arg %in% additional_params_names) {
        stop(sprintf("Argument '%s' is no longer supported. Use type='%s' instead."
                     , arg
                     , old_args_for_type[[arg]]))
      }
    }

    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , toString(names(additional_params))
      , ". These are ignored. Use argument 'params' instead."
    ))
  }

  # The objective may be missing from 'params'. Every comparison against it
  # below is guarded, because `if (NULL == "...")` is an error in R.
  objective <- object$params$objective
  has_objective <- !is.null(objective)

  if (has_objective && objective == "none" && type %in% c("class", "response")) {
    warning("Prediction types 'class' and 'response' are not supported for custom objectives.")
    type <- "raw"
  }

  # Translate 'type' into the logical flags expected by the Booster's predict method.
  # Any 'type' other than "raw", "leaf" or "contrib" leaves all three flags FALSE.
  pred <- object$predict(
    data = newdata
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = (type == "raw")
    , predleaf = (type == "leaf")
    , predcontrib = (type == "contrib")
    , header = header
    , params = params
  )

  # Class labels are only derived for the objectives listed here. For any other
  # objective (or when the objective is unknown) the predictions are returned as-is.
  if (type == "class" && has_objective) {
    if (objective == "binary") {
      # threshold at 0.5
      pred <- as.integer(pred >= 0.5)
    } else if (objective %in% c("multiclass", "multiclassova")) {
      # column index of the row-wise maximum, shifted to zero-based labels
      pred <- max.col(pred) - 1L
    }
  }
  return(pred)
}

921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  booster_private <- x$.__enclos_env__$private
  handle <- booster_private$handle

  if (lgb.is.null.handle(handle)) {
    # Nothing can be queried without a live handle: bare header plus a notice
    cat("LightGBM Model\n")
    cat("(Booster handle is invalid)\n")
  } else {
    # Header line, with singular/plural wording for the tree count
    n_trees <- x$current_iter()
    header_line <- if (n_trees == 1L) {
      "LightGBM Model (1 tree)\n"
    } else {
      sprintf("LightGBM Model (%d trees)\n", n_trees)
    }
    cat(header_line)

    # Objective line: "none" is displayed as "custom", and the class count is
    # appended only when there is more than one class
    objective_label <- x$params$objective
    if (objective_label == "none") {
      objective_label <- "custom"
    }
    n_classes <- booster_private$num_class
    objective_line <- if (n_classes == 1L) {
      sprintf("Objective: %s\n", objective_label)
    } else {
      sprintf("Objective: %s (%d classes)\n", objective_label, n_classes)
    }
    cat(objective_line)

    # Feature count is read from the native booster handle
    n_features <- .Call(LGBM_BoosterGetNumFeature_R, handle)
    cat(sprintf("Fitted to dataset with %d columns\n", n_features))
  }
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # The summary is the printed representation; arguments in '...' are not forwarded
  return(print(x = object))
}

981
982
#' @name lgb.load
#' @title Load LightGBM model
983
984
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
Guolin Ke's avatar
Guolin Ke committed
985
#' @param filename path of model file
986
#' @param model_str a str containing the model (as a `character` or `raw` vector)
987
#'
988
#' @return lgb.Booster
989
#'
Guolin Ke's avatar
Guolin Ke committed
990
#' @examples
991
#' \donttest{
992
993
994
995
996
997
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
998
999
1000
1001
1002
1003
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
1004
#' valids <- list(test = dtest)
1005
1006
1007
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
1008
#'   , nrounds = 5L
1009
#'   , valids = valids
1010
#'   , early_stopping_rounds = 3L
1011
#' )
1012
1013
1014
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
1015
1016
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
1017
#' }
Guolin Ke's avatar
Guolin Ke committed
1018
#' @export
1019
lgb.load <- function(filename = NULL, model_str = NULL) {

  # A file path takes precedence over a model string when both are supplied
  if (!is.null(filename)) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # file.exists() is vectorized, so a zero-length or multi-element 'filename'
    # would otherwise reach the 'if' below with a non-scalar condition
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a string")
    }
    filename <- path.expand(filename)
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (!is.null(model_str)) {
    # The serialized model may be either text or raw bytes
    if (!is.raw(model_str) && !is.character(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

1045
1046
1047
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
Guolin Ke's avatar
Guolin Ke committed
1048
1049
1050
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration Number of iterations to include in the saved model. NULL or <= 0 means use the best iteration.
1051
#'
1052
#' @return lgb.Booster
1053
#'
Guolin Ke's avatar
Guolin Ke committed
1054
#' @examples
1055
#' \donttest{
1056
1057
1058
1059
1060
1061
1062
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
1063
1064
1065
1066
1067
1068
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
1069
#' valids <- list(test = dtest)
1070
1071
1072
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
1073
#'   , nrounds = 10L
1074
#'   , valids = valids
1075
#'   , early_stopping_rounds = 5L
1076
#' )
1077
#' lgb.save(model, tempfile(fileext = ".txt"))
1078
#' }
Guolin Ke's avatar
Guolin Ke committed
1079
#' @export
1080
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # Exactly one path is required
  is_single_string <- is.character(filename) && length(filename) == 1L
  if (!is_single_string) {
    stop("lgb.save: filename should be a string")
  }

  # Expand the path, then hand it to the Booster's own save method.
  # Whatever that method returns is passed back invisibly.
  saved <- booster$save_model(
    filename = path.expand(filename)
    , num_iteration = num_iteration
  )
  return(invisible(saved))
}

1101
1102
1103
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
Guolin Ke's avatar
Guolin Ke committed
1104
1105
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration Number of iterations to include in the dumped model. NULL or <= 0 means use the best iteration.
1106
#'
Guolin Ke's avatar
Guolin Ke committed
1107
#' @return json format of model
1108
#'
Guolin Ke's avatar
Guolin Ke committed
1109
#' @examples
1110
#' \donttest{
1111
1112
1113
1114
1115
1116
1117
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
1118
1119
1120
1121
1122
1123
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
1124
#' valids <- list(test = dtest)
1125
1126
1127
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
1128
#'   , nrounds = 10L
1129
#'   , valids = valids
1130
#'   , early_stopping_rounds = 5L
1131
#' )
1132
#' json_model <- lgb.dump(model)
1133
#' }
Guolin Ke's avatar
Guolin Ke committed
1134
#' @export
1135
lgb.dump <- function(booster, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    # The message names this function so the failing call can be identified
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the Booster's own dump method for the requested iteration
  return(booster$dump_model(num_iteration = num_iteration))

}

1146
1147
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
1148
1149
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
Guolin Ke's avatar
Guolin Ke committed
1150
#' @param booster Object of class \code{lgb.Booster}
1151
1152
1153
1154
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
Guolin Ke's avatar
Guolin Ke committed
1155
#' @param is_err TRUE will return evaluation error instead
1156
#'
1157
#' @return numeric vector of evaluation result
1158
#'
1159
#' @examples
1160
#' \donttest{
1161
#' # train a regression model
1162
1163
1164
1165
1166
1167
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
1168
1169
1170
1171
1172
1173
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
1174
#' valids <- list(test = dtest)
1175
1176
1177
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
1178
#'   , nrounds = 5L
1179
1180
#'   , valids = valids
#' )
1181
1182
1183
1184
1185
1186
1187
1188
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
1189
#' lgb.get.eval.result(model, "test", "l2")
1190
#' }
Guolin Ke's avatar
Guolin Ke committed
1191
#' @export
1192
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , toString(data_names)
      , "]"
    ))
  }

  # Check that the requested metric was recorded for this dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , toString(eval_names)
      , "]"
    ))
  }

  # Metric values and their errors are stored side by side under different keys;
  # pick the one that was asked for
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Shift the requested iteration numbers by (start_iter - 1) to turn them
  # into positions within 'result'
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}