#' @importFrom R6 R6Class
#' @importFrom utils modifyList
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Index of the best iteration found during training (-1L if none recorded)
    best_iter = -1L,
    # Score achieved at the best iteration
    best_score = NA_real_,
    # Training parameters used to create / loaded with this booster
    params = list(),
    # Per-dataset evaluation history collected during training
    record_evals = list(),
    # Optional data processor attached by higher-level interfaces
    data_processor = NULL,
    # Create a starter booster from exactly one of: a training Dataset,
    # a model file on disk, or an in-memory model string.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      if (!is.null(train_set)) {

        if (!.is_Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }
        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        # Dataset-level parameters take precedence over those passed here
        params <- utils::modifyList(params, train_set$get_params())
        params_str <- .params2str(params = params)

        # Store booster handle
        handle <- .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
        )

        # Create private booster information
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        if (!is.null(private$init_predictor)) {

          # Continued training: merge the predictor's booster into this one
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
          )

        }

        # Track whether the current iteration has been predicted per dataset
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        # Do we have a model file as character?
        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        modelfile <- path.expand(modelfile)

        # Create booster from model file on disk
        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )
        params <- private$get_loaded_param(handle)

      } else if (!is.null(model_str)) {

        # Do we have a model_str as character/raw?
        if (!is.raw(model_str) && !is.character(model_str)) {
          stop("lgb.Booster: Can only use a character/raw vector as model_str")
        }

        # Create booster from serialized model string
        handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )

      } else {

        # None of the three sources was provided
        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      private$num_class <- 1L
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
      )

      self$params <- params

      return(invisible(NULL))

    },
110

111
    # Set training data name
Guolin Ke's avatar
Guolin Ke committed
112
    set_train_data_name = function(name) {
113

114
      # Set name
Guolin Ke's avatar
Guolin Ke committed
115
      private$name_train_set <- name
116
      return(invisible(self))
117

Guolin Ke's avatar
Guolin Ke committed
118
    },
119

120
    # Add validation data
Guolin Ke's avatar
Guolin Ke committed
121
    add_valid = function(data, name) {
122

123
      if (!.is_Dataset(data)) {
124
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
Guolin Ke's avatar
Guolin Ke committed
125
      }
126

Guolin Ke's avatar
Guolin Ke committed
127
      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
128
129
130
131
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
Guolin Ke's avatar
Guolin Ke committed
132
      }
133

134
135
      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
Guolin Ke's avatar
Guolin Ke committed
136
      }
137

138
      # Add validation data to booster
139
140
      .Call(
        LGBM_BoosterAddValidData_R
141
142
143
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )
144

145
146
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
147
      private$num_dataset <- private$num_dataset + 1L
148
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)
149

150
      return(invisible(self))
151

Guolin Ke's avatar
Guolin Ke committed
152
    },
153

154
    reset_parameter = function(params) {
155

156
      if (methods::is(self$params, "list")) {
157
        params <- utils::modifyList(self$params, params)
158
159
      }

160
      params_str <- .params2str(params = params)
161

162
163
      self$restore_handle()

164
165
      .Call(
        LGBM_BoosterResetParameter_R
166
167
168
        , private$handle
        , params_str
      )
169
      self$params <- params
170

171
      return(invisible(self))
172

Guolin Ke's avatar
Guolin Ke committed
173
    },
174

175
    # Perform boosting update iteration
Guolin Ke's avatar
Guolin Ke committed
176
    update = function(train_set = NULL, fobj = NULL) {
177

178
179
180
181
182
183
      if (is.null(train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

Guolin Ke's avatar
Guolin Ke committed
184
      if (!is.null(train_set)) {
185

186
        if (!.is_Dataset(train_set)) {
Guolin Ke's avatar
Guolin Ke committed
187
188
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }
189

Guolin Ke's avatar
Guolin Ke committed
190
        if (!identical(train_set$predictor, private$init_predictor)) {
191
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
Guolin Ke's avatar
Guolin Ke committed
192
        }
193

194
195
        .Call(
          LGBM_BoosterResetTrainingData_R
196
197
198
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )
199

200
        private$train_set <- train_set
201
        private$train_set_version <- train_set$.__enclos_env__$private$version
202

Guolin Ke's avatar
Guolin Ke committed
203
      }
204

205
      # Check if objective is empty
Guolin Ke's avatar
Guolin Ke committed
206
      if (is.null(fobj)) {
207
208
209
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
210
        # Boost iteration from known objective
211
212
        .Call(
          LGBM_BoosterUpdateOneIter_R
213
214
          , private$handle
        )
215

Guolin Ke's avatar
Guolin Ke committed
216
      } else {
217

218
219
220
        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
221
        if (!private$set_objective_to_none) {
222
          self$reset_parameter(params = list(objective = "none"))
223
          private$set_objective_to_none <- TRUE
224
        }
225
        # Perform objective calculation
226
227
        preds <- private$inner_predict(1L)
        gpair <- fobj(preds, private$train_set)
228

229
        # Check for gradient and hessian as list
230
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
231
          stop("lgb.Booster.update: custom objective should
232
233
            return a list with attributes (hess, grad)")
        }
234

235
236
237
238
239
240
241
242
243
244
245
        # Check grad and hess have the right shape
        n_grad <- length(gpair$grad)
        n_hess <- length(gpair$hess)
        n_preds <- length(preds)
        if (n_grad != n_preds) {
          stop(sprintf("Expected custom objective function to return grad with length %d, got %d.", n_preds, n_grad))
        }
        if (n_hess != n_preds) {
          stop(sprintf("Expected custom objective function to return hess with length %d, got %d.", n_preds, n_hess))
        }

246
        # Return custom boosting gradient/hessian
247
248
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
249
250
251
          , private$handle
          , gpair$grad
          , gpair$hess
252
          , n_preds
253
        )
254

Guolin Ke's avatar
Guolin Ke committed
255
      }
256

257
      # Loop through each iteration
258
      for (i in seq_along(private$is_predicted_cur_iter)) {
Guolin Ke's avatar
Guolin Ke committed
259
260
        private$is_predicted_cur_iter[[i]] <- FALSE
      }
261

262
      return(invisible(self))
263

Guolin Ke's avatar
Guolin Ke committed
264
    },
265

266
    # Return one iteration behind
Guolin Ke's avatar
Guolin Ke committed
267
    rollback_one_iter = function() {
268

269
270
      self$restore_handle()

271
272
      .Call(
        LGBM_BoosterRollbackOneIter_R
273
274
        , private$handle
      )
275

276
      # Loop through each iteration
277
      for (i in seq_along(private$is_predicted_cur_iter)) {
Guolin Ke's avatar
Guolin Ke committed
278
279
        private$is_predicted_cur_iter[[i]] <- FALSE
      }
280

281
      return(invisible(self))
282

Guolin Ke's avatar
Guolin Ke committed
283
    },
284

285
    # Get current iteration
Guolin Ke's avatar
Guolin Ke committed
286
    current_iter = function() {
287

288
289
      self$restore_handle()

290
      cur_iter <- 0L
291
292
293
294
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
295
      )
296
      return(cur_iter)
297

Guolin Ke's avatar
Guolin Ke committed
298
    },
299

300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
    # Number of trees per iteration
    num_trees_per_iter = function() {

      self$restore_handle()

      trees_per_iter <- 1L
      .Call(
        LGBM_BoosterNumModelPerIteration_R
        , private$handle
        , trees_per_iter
      )
      return(trees_per_iter)

    },

    # Total number of trees in the model
    num_trees = function() {

      self$restore_handle()

      # Out-parameter filled in by the C routine
      ntrees <- 0L
      .Call(
        LGBM_BoosterNumberOfTotalModel_R
        , private$handle
        , ntrees
      )
      return(ntrees)

    },

    # Number of boosting iterations (= rounds) = total trees / trees per round
    num_iter = function() {

      ntrees <- self$num_trees()
      trees_per_iter <- self$num_trees_per_iter()

      return(ntrees / trees_per_iter)

    },

340
    # Get upper bound
341
    upper_bound = function() {
342

343
344
      self$restore_handle()

345
      upper_bound <- 0.0
346
347
348
349
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
350
      )
351
      return(upper_bound)
352
353
354
355

    },

    # Lower bound of the model's raw output values
    lower_bound = function() {

      self$restore_handle()

      # Out-parameter filled in by the C routine
      lower_bound <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)

    },

370
    # Evaluate data on metrics
Guolin Ke's avatar
Guolin Ke committed
371
    eval = function(data, name, feval = NULL) {
372

373
      if (!.is_Dataset(data)) {
374
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
Guolin Ke's avatar
Guolin Ke committed
375
      }
376

377
      # Check for identical data
378
      data_idx <- 0L
379
      if (identical(data, private$train_set)) {
380
        data_idx <- 1L
381
      } else {
382

383
        # Check for validation data
384
        if (length(private$valid_sets) > 0L) {
385

386
          for (i in seq_along(private$valid_sets)) {
387

388
            # Check for identical validation data with training data
Guolin Ke's avatar
Guolin Ke committed
389
            if (identical(data, private$valid_sets[[i]])) {
390

391
              # Found identical data, skip
392
              data_idx <- i + 1L
Guolin Ke's avatar
Guolin Ke committed
393
              break
394

Guolin Ke's avatar
Guolin Ke committed
395
            }
396

Guolin Ke's avatar
Guolin Ke committed
397
          }
398

Guolin Ke's avatar
Guolin Ke committed
399
        }
400

Guolin Ke's avatar
Guolin Ke committed
401
      }
402

403
      # Check if evaluation was not done
404
      if (data_idx == 0L) {
405

406
        # Add validation data by name
Guolin Ke's avatar
Guolin Ke committed
407
408
        self$add_valid(data, name)
        data_idx <- private$num_dataset
409

Guolin Ke's avatar
Guolin Ke committed
410
      }
411

412
      # Evaluate data
413
414
415
416
417
418
      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
419
      )
420

Guolin Ke's avatar
Guolin Ke committed
421
    },
422

423
    # Evaluation training data
Guolin Ke's avatar
Guolin Ke committed
424
    eval_train = function(feval = NULL) {
425
      return(private$inner_eval(private$name_train_set, 1L, feval))
Guolin Ke's avatar
Guolin Ke committed
426
    },
427

428
    # Evaluation validation data
Guolin Ke's avatar
Guolin Ke committed
429
    eval_valid = function(feval = NULL) {
430

431
      ret <- list()
432

433
      if (length(private$valid_sets) <= 0L) {
434
435
        return(ret)
      }
436

437
      for (i in seq_along(private$valid_sets)) {
438
439
        ret <- append(
          x = ret
440
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
441
        )
Guolin Ke's avatar
Guolin Ke committed
442
      }
443

444
      return(ret)
445

Guolin Ke's avatar
Guolin Ke committed
446
    },
447

448
    # Save model
449
450
451
452
453
454
    save_model = function(
      filename
      , num_iteration = NULL
      , feature_importance_type = 0L
      , start_iteration = 1L
    ) {
455

456
457
      self$restore_handle()

458
459
460
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
461

462
463
      filename <- path.expand(filename)

464
465
      .Call(
        LGBM_BoosterSaveModel_R
466
467
        , private$handle
        , as.integer(num_iteration)
468
        , as.integer(feature_importance_type)
469
        , filename
470
        , as.integer(start_iteration) - 1L  # Turn to 0-based
471
      )
472

473
      return(invisible(self))
Guolin Ke's avatar
Guolin Ke committed
474
    },
475

476
477
478
479
480
481
    save_model_to_string = function(
      num_iteration = NULL
      , feature_importance_type = 0L
      , as_char = TRUE
      , start_iteration = 1L
    ) {
482
483

      self$restore_handle()
484

485
486
487
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
488

489
      model_str <- .Call(
490
          LGBM_BoosterSaveModelToString_R
491
492
493
          , private$handle
          , as.integer(num_iteration)
          , as.integer(feature_importance_type)
494
          , as.integer(start_iteration) - 1L  # Turn to 0-based
495
496
      )

497
498
499
500
      if (as_char) {
        model_str <- rawToChar(model_str)
      }

501
      return(model_str)
502

503
    },
504

505
    # Dump model in memory
506
507
508
    dump_model = function(
      num_iteration = NULL, feature_importance_type = 0L, start_iteration = 1L
    ) {
509

510
511
      self$restore_handle()

512
513
514
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
515

516
      model_str <- .Call(
517
518
519
520
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
521
        , as.integer(start_iteration) - 1L  # Turn to 0-based
522
523
      )

524
      return(model_str)
525

Guolin Ke's avatar
Guolin Ke committed
526
    },
527

528
    # Predict on new data
Guolin Ke's avatar
Guolin Ke committed
529
    predict = function(data,
530
                       start_iteration = NULL,
531
532
533
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
534
                       predcontrib = FALSE,
535
                       header = FALSE,
536
                       params = list()) {
537

538
539
      self$restore_handle()

540
541
542
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
543

544
545
546
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }
547

548
549
550
551
552
553
554
555
      # possibly override keyword arguments with parameters
      #
      # NOTE: this length() check minimizes the latency introduced by these checks,
      #       for the common case where params is empty
      #
      # NOTE: doing this here instead of in Predictor$predict() to keep
      #       Predictor$predict() as fast as possible
      if (length(params) > 0L) {
556
        params <- .check_wrapper_param(
557
558
559
560
          main_param_name = "predict_raw_score"
          , params = params
          , alternative_kwarg_value = rawscore
        )
561
        params <- .check_wrapper_param(
562
563
564
565
          main_param_name = "predict_leaf_index"
          , params = params
          , alternative_kwarg_value = predleaf
        )
566
        params <- .check_wrapper_param(
567
568
569
570
571
572
573
574
575
          main_param_name = "predict_contrib"
          , params = params
          , alternative_kwarg_value = predcontrib
        )
        rawscore <- params[["predict_raw_score"]]
        predleaf <- params[["predict_leaf_index"]]
        predcontrib <- params[["predict_contrib"]]
      }

576
      # Predict on new data
577
578
579
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
580
        , fast_predict_config = private$fast_predict_config
581
      )
582
583
      return(
        predictor$predict(
584
585
586
587
588
589
590
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
591
        )
592
      )
593

594
    },
595

596
597
    # Transform into predictor
    to_predictor = function() {
598
      return(Predictor$new(modelfile = private$handle))
Guolin Ke's avatar
Guolin Ke committed
599
    },
600

601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
    configure_fast_predict = function(csr = FALSE,
                                      start_iteration = NULL,
                                      num_iteration = NULL,
                                      rawscore = FALSE,
                                      predleaf = FALSE,
                                      predcontrib = FALSE,
                                      params = list()) {

      self$restore_handle()
      ncols <- .Call(LGBM_BoosterGetNumFeature_R, private$handle)

      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      if (!csr) {
        fun <- LGBM_BoosterPredictForMatSingleRowFastInit_R
      } else {
        fun <- LGBM_BoosterPredictForCSRSingleRowFastInit_R
      }

      fast_handle <- .Call(
        fun
        , private$handle
        , ncols
        , rawscore
        , predleaf
        , predcontrib
        , start_iteration
        , num_iteration
634
        , .params2str(params = params)
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
      )

      private$fast_predict_config <- list(
        handle = fast_handle
        , csr = as.logical(csr)
        , ncols = ncols
        , start_iteration = start_iteration
        , num_iteration = num_iteration
        , rawscore = as.logical(rawscore)
        , predleaf = as.logical(predleaf)
        , predcontrib = as.logical(predcontrib)
        , params = params
      )

      return(invisible(NULL))
    },

652
653
    # Used for serialization
    raw = NULL,
654

655
656
657
658
659
660
    # Store serialized raw bytes in model object
    save_raw = function() {
      if (is.null(self$raw)) {
        self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      }
      return(invisible(NULL))
661

662
    },
663

664
665
    drop_raw = function() {
      self$raw <- NULL
666
      return(invisible(NULL))
667
    },
668

669
    check_null_handle = function() {
670
      return(.is_null_handle(private$handle))
671
672
673
674
675
676
677
678
679
680
681
682
683
684
    },

    # Rebuild the C++ handle from stored raw bytes if it has been lost;
    # errors if there are no bytes to rebuild from.
    restore_handle = function() {
      if (self$check_null_handle()) {
        if (is.null(self$raw)) {
          .Call(LGBM_NullBoosterHandleError_R)
        }
        private$handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
      }
      return(invisible(NULL))
    },

    # Expose the raw C++ handle
    get_handle = function() {
      return(private$handle)
    }

  ),
  private = list(
689
690
691
692
693
694
695
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
696
697
    num_class = 1L,
    num_dataset = 0L,
698
699
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
700
    higher_better_inner_eval = NULL,
701
    set_objective_to_none = FALSE,
702
    train_set_version = 0L,
703
    fast_predict_config = list(),
704
705
706
707
708
709
710
711
712
713
714

    # finalize() will free up the handles
    finalize = function() {
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL
      return(invisible(NULL))
    },

715
716
    # Predict data
    inner_predict = function(idx) {
717

718
      # Store data name
Guolin Ke's avatar
Guolin Ke committed
719
      data_name <- private$name_train_set
720

721
722
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
723
      }
724

725
      # Check for unknown dataset (over the maximum provided range)
Guolin Ke's avatar
Guolin Ke committed
726
727
728
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }
729

730
      # Check for prediction buffer
Guolin Ke's avatar
Guolin Ke committed
731
      if (is.null(private$predict_buffer[[data_name]])) {
732

733
        # Store predictions
734
        npred <- 0L
735
736
        .Call(
          LGBM_BoosterGetNumPredict_R
737
          , private$handle
738
          , as.integer(idx - 1L)
739
          , npred
740
        )
741
        private$predict_buffer[[data_name]] <- numeric(npred)
742

Guolin Ke's avatar
Guolin Ke committed
743
      }
744

745
      # Check if current iteration was already predicted
Guolin Ke's avatar
Guolin Ke committed
746
      if (!private$is_predicted_cur_iter[[idx]]) {
747

748
        # Use buffer
749
750
        .Call(
          LGBM_BoosterGetPredict_R
751
          , private$handle
752
          , as.integer(idx - 1L)
753
          , private$predict_buffer[[data_name]]
754
        )
Guolin Ke's avatar
Guolin Ke committed
755
756
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }
757

758
      return(private$predict_buffer[[data_name]])
Guolin Ke's avatar
Guolin Ke committed
759
    },
760

761
    # Get evaluation information
Guolin Ke's avatar
Guolin Ke committed
762
    get_eval_info = function() {
763

Guolin Ke's avatar
Guolin Ke committed
764
      if (is.null(private$eval_names)) {
765
        eval_names <- .Call(
766
          LGBM_BoosterGetEvalNames_R
767
768
          , private$handle
        )
769

770
        if (length(eval_names) > 0L) {
771

772
          # Parse and store privately names
773
          private$eval_names <- eval_names
774
775
776

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
777
          metric_names <- gsub("@.*", "", eval_names)
778
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]
779

Guolin Ke's avatar
Guolin Ke committed
780
        }
781

Guolin Ke's avatar
Guolin Ke committed
782
      }
783

784
      return(private$eval_names)
785

Guolin Ke's avatar
Guolin Ke committed
786
    },
787

788
789
790
791
792
793
794
795
796
797
798
799
800
801
    get_loaded_param = function(handle) {
      params_str <- .Call(
        LGBM_BoosterGetLoadedParam_R
        , handle
      )
      params <- jsonlite::fromJSON(params_str)
      if ("interaction_constraints" %in% names(params)) {
        params[["interaction_constraints"]] <- lapply(params[["interaction_constraints"]], function(x) x + 1L)
      }

      return(params)

    },

Guolin Ke's avatar
Guolin Ke committed
802
    inner_eval = function(data_name, data_idx, feval = NULL) {
803

804
      # Check for unknown dataset (over the maximum provided range)
Guolin Ke's avatar
Guolin Ke committed
805
806
807
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }
808

809
810
      self$restore_handle()

Guolin Ke's avatar
Guolin Ke committed
811
      private$get_eval_info()
812

Guolin Ke's avatar
Guolin Ke committed
813
      ret <- list()
814

815
      if (length(private$eval_names) > 0L) {
816

817
818
        # Create evaluation values
        tmp_vals <- numeric(length(private$eval_names))
819
820
        .Call(
          LGBM_BoosterGetEval_R
821
          , private$handle
822
          , as.integer(data_idx - 1L)
823
          , tmp_vals
824
        )
825

826
        for (i in seq_along(private$eval_names)) {
827

828
829
830
831
832
          # Store evaluation and append to return
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
Guolin Ke's avatar
Guolin Ke committed
833
          res$higher_better <- private$higher_better_inner_eval[i]
834
          ret <- append(ret, list(res))
835

Guolin Ke's avatar
Guolin Ke committed
836
        }
837

Guolin Ke's avatar
Guolin Ke committed
838
      }
839

840
      # Check if there are evaluation metrics
Guolin Ke's avatar
Guolin Ke committed
841
      if (!is.null(feval)) {
842

843
        # Check if evaluation metric is a function
844
        if (!is.function(feval)) {
Guolin Ke's avatar
Guolin Ke committed
845
846
          stop("lgb.Booster.eval: feval should be a function")
        }
847

Guolin Ke's avatar
Guolin Ke committed
848
        data <- private$train_set
849

850
        # Check if data to assess is existing differently
851
852
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
853
        }
854

855
        # Perform function evaluation
856
        res <- feval(private$inner_predict(data_idx), data)
857

858
        if (is.null(res$name) || is.null(res$value) ||  is.null(res$higher_better)) {
859
860
861
          stop(
            "lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)"
          )
862
        }
863

864
        # Append names and evaluation
Guolin Ke's avatar
Guolin Ke committed
865
        res$data_name <- data_name
866
        ret <- append(ret, list(res))
Guolin Ke's avatar
Guolin Ke committed
867
      }
868

869
      return(ret)
870

Guolin Ke's avatar
Guolin Ke committed
871
    }
872

Guolin Ke's avatar
Guolin Ke committed
873
874
875
  )
)

876
#' @name lgb_predict_shared_params
877
#' @title Shared prediction parameter docs
878
879
880
881
882
#' @param type Type of prediction to output. Allowed types are:\itemize{
#'             \item \code{"response"}: will output the predicted score according to the objective function being
#'                   optimized (depending on the link function that the objective uses), after applying any necessary
#'                   transformations - for example, for \code{objective="binary"}, it will output class probabilities.
#'             \item \code{"class"}: for classification objectives, will output the class with the highest predicted
883
884
885
#'                   probability. For other objectives, will output the same as "response". Note that \code{"class"} is
#'                   not a supported type for \link{lgb.configure_fast_predict} (see the documentation of that function
#'                   for more details).
886
887
888
889
890
891
892
#'             \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations'
#'                   results) from which the "response" number is produced for a given objective function - for example,
#'                   for \code{objective="binary"}, this corresponds to log-odds. For many objectives such as
#'                   "regression", since no transformation is applied, the output will be the same as for "response".
#'             \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls
#'                   in each tree in the model, outputted as integers, with one column per tree.
#'             \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an
#'                   intercept (each feature will produce one column).
#'             }
#'
#'             Note that, if using custom objectives, types "class" and "response" will not be available and will
#'             default towards using "raw" instead.
#'
#'             If the model was fit through function \link{lightgbm} and it was passed a factor as labels,
#'             passing the prediction type through \code{params} instead of through this argument might
#'             result in factor levels for classification objectives not being applied correctly to the
#'             resulting output.
#'
#'             \emph{New in version 4.0.0}
#'
#' @param start_iteration int or None, optional (default=None)
#'                        Start index of the iteration to predict.
#'                        If None or <= 0, starts from the first iteration.
#' @param num_iteration int or None, optional (default=None)
#'                      Limit number of iterations in the prediction.
#'                      If None, if the best iteration exists and start_iteration is None or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values. Where these conflict with the values of keyword arguments to this function,
#'               the values in \code{params} take precedence.
#' @details This page contains shared documentation for prediction-related parameters used throughout the package.
#' @keywords internal
NULL

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#'
#'              \emph{New in version 4.0.0}
#'
#' @details If the model object has been configured for fast single-row predictions through
#'          \link{lgb.configure_fast_predict}, this function will use the prediction parameters
#'          that were configured for it - as such, extra prediction parameters should not be passed
#'          here, otherwise the configuration will be ignored and the slow route will be taken.
#' @inheritParams lgb_predict_shared_params
#' @param object Object of class \code{lgb.Booster}
#' @param newdata a \code{matrix} object, a \code{dgCMatrix}, a \code{dgRMatrix} object, a \code{dsparseVector} object,
#'                or a character representing a path to a text file (CSV, TSV, or LibSVM).
#'
#'                For sparse inputs, if predictions are only going to be made for a single row, it will be faster to
#'                use CSR format, in which case the data may be passed as either a single-row CSR matrix (class
#'                \code{dgRMatrix} from package \code{Matrix}) or as a sparse numeric vector (class
#'                \code{dsparseVector} from package \code{Matrix}).
#'
#'                If single-row predictions are going to be performed frequently, it is recommended to
#'                pre-configure the model object for fast single-row sparse predictions through function
#'                \link{lgb.configure_fast_predict}.
#'
#'                \emph{Changed from 'data', in version 4.0.0}
#'
#' @param header only used for prediction for text file. True if text file has header
#' @param ... ignored
#' @return For prediction types that are meant to always return one output per observation (e.g. when predicting
#'         \code{type="response"} or \code{type="raw"} on a binary classification or regression objective), will
#'         return a vector with one element per row in \code{newdata}.
#'
#'         For prediction types that are meant to return more than one output per observation (e.g. when predicting
#'         \code{type="response"} or \code{type="raw"} on a multi-class objective, or when predicting
#'         \code{type="leaf"}, regardless of objective), will return a matrix with one row per observation in
#'         \code{newdata} and one column per output.
#'
#'         For \code{type="leaf"} predictions, will return a matrix with one row per observation in \code{newdata}
#'         and one column per tree. Note that for multiclass objectives, LightGBM trains one tree per class at each
#'         boosting iteration. That means that, for example, for a multiclass model with 3 classes, the leaf
#'         predictions for the first class can be found in columns 1, 4, 7, 10, etc.
#'
#'         For \code{type="contrib"}, will return a matrix of SHAP values with one row per observation in
#'         \code{newdata} and columns corresponding to features. For regression, ranking, cross-entropy, and binary
#'         classification objectives, this matrix contains one column per feature plus a final column containing the
#'         Shapley base value. For multiclass objectives, this matrix will represent \code{num_classes} such matrices,
#'         in the order "feature contributions for first class, feature contributions for second class, feature
#'         contributions for third class, etc.".
#'
#'         If the model was fit through function \link{lightgbm} and it was passed a factor as labels, predictions
#'         returned from this function will retain the factor levels (either as values for \code{type="class"}, or
#'         as column names for \code{type="response"} and \code{type="raw"} for multi-class objectives). Note that
#'         passing the requested prediction type under \code{params} instead of through \code{type} might result in
#'         the factor levels not being present in the output.
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'     )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                newdata,
                                type = "response",
                                start_iteration = NULL,
                                num_iteration = NULL,
                                header = FALSE,
                                params = list(),
                                ...) {

  if (!.is_Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster", q = FALSE))
  }

  # Arguments removed in v4.0.0 get informative errors instead of being
  # silently swallowed by '...'; any other names passed through '...' are ignored.
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    additional_params_names <- names(additional_params)
    if ("reshape" %in% additional_params_names) {
      stop("'reshape' argument is no longer supported.")
    }

    old_args_for_type <- list(
      "rawscore" = "raw"
      , "predleaf" = "leaf"
      , "predcontrib" = "contrib"
    )
    for (arg in names(old_args_for_type)) {
      if (arg %in% additional_params_names) {
        stop(sprintf("Argument '%s' is no longer supported. Use type='%s' instead."
                     , arg
                     , old_args_for_type[[arg]]))
      }
    }

    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , toString(names(additional_params))
      , ". These are ignored. Use argument 'params' instead."
    ))
  }

  # Custom objectives apply no transformation, so "class"/"response" cannot be
  # produced - fall back to raw scores with a warning.
  if (!is.null(object$params$objective) && object$params$objective == "none" && type %in% c("class", "response")) {
    warning("Prediction types 'class' and 'response' are not supported for custom objectives.")
    type <- "raw"
  }

  # Translate the user-facing 'type' into the flags understood by Booster$predict()
  rawscore <- FALSE
  predleaf <- FALSE
  predcontrib <- FALSE
  if (type == "raw") {
    rawscore <- TRUE
  } else if (type == "leaf") {
    predleaf <- TRUE
  } else if (type == "contrib") {
    predcontrib <- TRUE
  }

  pred <- object$predict(
    data = newdata
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , params = params
  )
  if (type == "class") {
    # FIX: 'objective' can be NULL (e.g. a booster loaded without stored
    # parameters); the previous code then evaluated if (logical(0)) and errored
    # with "argument is of length zero". Guard explicitly, mirroring the
    # !is.null() check done for the custom-objective case above.
    objective <- object$params$objective
    if (!is.null(objective)) {
      if (objective %in% .BINARY_OBJECTIVES()) {
        # binary classification: threshold the predicted probability at 0.5
        pred <- as.integer(pred >= 0.5)
      } else if (objective %in% .MULTICLASS_OBJECTIVES()) {
        # multi-class: 0-based index of the highest-scoring class per row
        pred <- max.col(pred) - 1L
      }
    }
  }
  # Models fitted through lightgbm() may carry a data processor (e.g. to
  # re-apply factor levels to classification output).
  if (!is.null(object$data_processor)) {
    pred <- object$data_processor$process_predictions(
      pred = pred
      , type = type
    )
  }
  return(pred)
}

#' @title Configure Fast Single-Row Predictions
#' @description Pre-configures a LightGBM model object to produce fast single-row predictions
#'              for a given input data type, prediction type, and parameters.
#' @details Calling this function multiple times with different parameters might not override
#'          the previous configuration and might trigger undefined behavior.
#'
#'          Any saved configuration for fast predictions might be lost after making a single-row
#'          prediction of a different type than what was configured (except for types "response" and
#'          "class", which can be switched between each other at any time without losing the configuration).
#'
#'          In some situations, setting a fast prediction configuration for one type of prediction
#'          might cause the prediction function to keep using that configuration for single-row
#'          predictions even if the requested type of prediction is different from what was configured.
#'
#'          Note that this function will not accept argument \code{type="class"} - for such cases, one
#'          can pass \code{type="response"} to this function and then \code{type="class"} to the
#'          \code{predict} function - the fast configuration will not be lost or altered if the switch
#'          is between "response" and "class".
#'
#'          The configuration does not survive de-serializations, so it has to be generated
#'          anew in every R process that is going to use it (e.g. if loading a model object
#'          through \code{readRDS}, whatever configuration was there previously will be lost).
#'
#'          Requesting a different prediction type or passing parameters to \link{predict.lgb.Booster}
#'          will cause it to ignore the fast-predict configuration and take the slow route instead
#'          (but be aware that an existing configuration might not always be overridden by supplying
#'          different parameters or prediction type, so make sure to check that the output is what
#'          was expected when a prediction is to be made on a single row for something different than
#'          what is configured).
#'
#'          Note that, if configuring a non-default prediction type (such as leaf indices),
#'          then that type must also be passed in the call to \link{predict.lgb.Booster} in
#'          order for it to use the configuration. This also applies for \code{start_iteration}
#'          and \code{num_iteration}, but \bold{the \code{params} list must be empty} in the call to \code{predict}.
#'
#'          Predictions about feature contributions do not allow a fast route for CSR inputs,
#'          and as such, this function will produce an error if passing \code{csr=TRUE} and
#'          \code{type = "contrib"} together.
#' @inheritParams lgb_predict_shared_params
#' @param model LightGBM model object (class \code{lgb.Booster}).
#'
#'              \bold{The object will be modified in-place}.
#' @param csr Whether the prediction function is going to be called on sparse CSR inputs.
#'            If \code{FALSE}, will be assumed that predictions are going to be called on single-row
#'            regular R matrices.
#' @return The same \code{model} that was passed as input, invisibly, with the desired
#'         configuration stored inside it and available to be used in future calls to
#'         \link{predict.lgb.Booster}.
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' library(lightgbm)
#' data(mtcars)
#' X <- as.matrix(mtcars[, -1L])
#' y <- mtcars[, 1L]
#' dtrain <- lgb.Dataset(X, label = y, params = list(max_bin = 5L))
#' params <- list(
#'   min_data_in_leaf = 2L
#'   , num_threads = 2L
#' )
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , obj = "regression"
#'   , nrounds = 5L
#'   , verbose = -1L
#' )
#' lgb.configure_fast_predict(model)
#'
#' x_single <- X[11L, , drop = FALSE]
#' predict(model, x_single)
#'
#' # Will not use it if the prediction to be made
#' # is different from what was configured
#' predict(model, x_single, type = "leaf")
#' }
#' @export
lgb.configure_fast_predict <- function(model,
                                       csr = FALSE,
                                       start_iteration = NULL,
                                       num_iteration = NULL,
                                       type = "response",
                                       params = list()) {
  # Validate inputs before touching the model object
  if (!.is_Booster(x = model)) {
    stop("lgb.configure_fast_predict: model should be an ", sQuote("lgb.Booster", q = FALSE))
  }
  if (type == "class") {
    stop("type='class' is not supported for 'lgb.configure_fast_predict'. Use 'response' instead.")
  }

  # Map the user-facing prediction 'type' onto the booster-level flags.
  # Exactly one of these can be TRUE; all stay FALSE for "response".
  rawscore <- type == "raw"
  predleaf <- type == "leaf"
  predcontrib <- type == "contrib"

  # Feature contributions have no fast path for CSR inputs
  if (csr && predcontrib) {
    stop("'lgb.configure_fast_predict' does not support feature contributions for CSR data.")
  }

  # Store the configuration on the booster (modified in-place) and
  # return the same object invisibly so calls can be chained.
  model$configure_fast_predict(
    csr = csr
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , params = params
  )
  return(invisible(model))
}

#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#'
#'              \emph{New in version 4.0.0}
#'
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input \code{x}, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  handle <- x$.__enclos_env__$private$handle

  if (.is_null_handle(handle)) {
    # Nothing can be queried from an invalid handle
    cat("LightGBM Model\n")
    cat("(Booster handle is invalid)\n")
  } else {
    # Header line: model plus number of boosting rounds
    ntrees <- x$current_iter()
    if (ntrees == 1L) {
      cat("LightGBM Model (1 tree)\n")
    } else {
      cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
    }

    # Objective line, normalizing missing and custom objectives
    obj <- x$params$objective
    if (is.null(obj)) {
      obj <- "(default)"
    }
    if (obj == "none") {
      obj <- "custom"
    }
    num_class <- x$.__enclos_env__$private$num_class
    if (num_class == 1L) {
      cat(sprintf("Objective: %s\n", obj))
    } else {
      cat(sprintf("Objective: %s (%d classes)\n"
          , obj
          , num_class))
    }

    # Number of features the booster was fitted on
    ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
    cat(sprintf("Fitted to dataset with %d columns\n", ncols))
  }
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#'
#'              \emph{New in version 4.0.0}
#'
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input \code{object}, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # Delegate to print.lgb.Booster, which returns its input invisibly
  return(invisible(print(object)))
}

#' @name lgb.load
#' @title Load LightGBM model
#' @description Load LightGBM takes in either a file path or model string.
#'              If both are provided, Load will default to loading from file
#' @param filename path of model file
#' @param model_str a str containing the model (as a \code{character} or \code{raw} vector)
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  # Loading from a file takes precedence when both inputs are supplied
  if (!is.null(filename)) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    filename <- path.expand(filename)
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  # No file given: try the in-memory model string instead
  if (!is.null(model_str)) {
    if (!is.character(model_str) && !is.raw(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  # Neither source was provided
  stop("lgb.load: either filename or model_str must be given")
}

#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename Saved filename
#' @param num_iteration Number of iterations to save, NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to save.
#'        For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#'        means "save the fifth, sixth, and seventh tree"
#'
#'        \emph{New in version 4.4.0}
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(
    booster, filename, num_iteration = NULL, start_iteration = 1L
  ) {

  # Validate inputs up front
  if (!.is_Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster", q = FALSE))
  }

  is_single_string <- is.character(filename) && length(filename) == 1L
  if (!is_single_string) {
    stop("lgb.save: filename should be a string")
  }
  filename <- path.expand(filename)

  # Delegate to the booster's own save method; return its result invisibly
  saved <- booster$save_model(
    filename = filename
    , num_iteration = num_iteration
    , start_iteration = start_iteration
  )
  return(invisible(saved))
}

#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration Number of iterations to be dumped. NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to dump.
#'        For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#'        means "dump the fifth, sixth, and seventh tree"
#'
#'        \emph{New in version 4.4.0}
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL, start_iteration = 1L) {

  # Only lgb.Booster objects can be dumped
  if (!.is_Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster", q = FALSE))
  }

  # Delegate to the booster's dump_model() for the requested iteration range
  json_model <- booster$dump_model(
    num_iteration = num_iteration, start_iteration = start_iteration
  )
  return(json_model)
}

#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Validate the booster and the lookup keys
  if (!.is_Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster", q = FALSE), " to get eval result")
  }
  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , toString(data_names)
      , "]"
    ))
  }

  # The requested metric must have been recorded for that dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , toString(eval_names)
      , "]"
    ))
  }

  # Select either the recorded metric values or their errors
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  # No subsetting requested: return everything
  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Shift the requested iterations so they index into the recorded results,
  # accounting for the iteration the recording started at
  delta <- booster$record_evals$start_iter - 1.0
  positions <- as.integer(iters) - delta

  return(as.numeric(result[positions]))
}