lgb.Booster.R 47.7 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
#' @importFrom R6 R6Class
2
#' @importFrom utils modifyList
James Lamb's avatar
James Lamb committed
3
#' Booster: R6 wrapper around a native LightGBM booster handle.
#' Instances are created from a training lgb.Dataset, a model file on disk,
#' or an in-memory model string (see `initialize`).
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Index of the best iteration (-1L until assigned elsewhere, e.g. by training code)
    best_iter = -1L,
    # Score associated with the best iteration (NA_real_ until assigned elsewhere)
    best_score = NA_real_,
    # Named list of parameters this booster was created / last reset with
    params = list(),
    # Evaluation results recorded across iterations (filled externally; not read in this class)
    record_evals = list(),
    # NOTE(review): not referenced in this class's visible code — presumably
    # attached by higher-level interfaces; confirm against callers
    data_processor = NULL,
13

14
15
    # Destructor: release the native booster handle so the C++ side can free
    # its memory, then clear the pointer on the R side.
    finalize = function() {
      .Call(LGBM_BoosterFree_R, private$handle)
      private$handle <- NULL
      return(invisible(NULL))
    },
23

24
25
    # Construct a booster from exactly one of three sources, checked in order:
    #   1. train_set  - an lgb.Dataset to train on
    #   2. modelfile  - path to a saved model on disk
    #   3. model_str  - an in-memory serialized model (character or raw)
    # Raises an error if none of the three is provided.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      if (!is.null(train_set)) {

        if (!.is_Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }
        # Reach into the Dataset's private environment for its native handle;
        # dataset-level parameters take precedence over those passed here
        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        params <- utils::modifyList(params, train_set$get_params())
        params_str <- .params2str(params = params)
        # Store booster handle
        handle <- .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
        )

        # Create private booster information
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        if (!is.null(private$init_predictor)) {

          # Continued training: merge the existing predictor's model into
          # the freshly created booster
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
          )

        }

        # Mark the training set's cached predictions as stale
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        # Do we have a model file as character?
        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        modelfile <- path.expand(modelfile)

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )
        # Recover the parameters stored inside the saved model
        params <- private$get_loaded_param(handle)

      } else if (!is.null(model_str)) {

        # Do we have a model_str as character/raw?
        if (!is.raw(model_str) && !is.character(model_str)) {
          stop("lgb.Booster: Can only use a character/raw vector as model_str")
        }

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )

      } else {

        # Booster non existent
        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      # Tag the external pointer and cache it; the native call below fills
      # private$num_class in place
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      private$num_class <- 1L
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
      )

      self$params <- params

      return(invisible(NULL))

    },
120

121
    # Record the label under which the training dataset is reported
    # (used as its key in prediction buffers and evaluation output)
    set_train_data_name = function(name) {
      private$name_train_set <- name
      return(invisible(self))
    },
129

130
    # Register an additional validation dataset with the booster under the
    # given display name.
    add_valid = function(data, name) {

      if (!.is_Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # A validation set must share the training set's predictor
      data_predictor <- data$.__enclos_env__$private$predictor
      if (!identical(data_predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Hand the dataset's native handle to the native booster
      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      # Mirror the registration in the R-side bookkeeping
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },
163

164
    # Replace booster parameters; the new values are merged on top of the
    # parameters currently stored on the object, then pushed to the native side.
    reset_parameter = function(params) {

      # Merge with the stored parameters when those are a list
      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }
      serialized_params <- .params2str(params = params)

      # Revive the native handle from raw bytes if it was invalidated
      self$restore_handle()

      .Call(LGBM_BoosterResetParameter_R, private$handle, serialized_params)
      self$params <- params

      return(invisible(self))

    },
184

185
    # Run one boosting iteration.
    #
    # train_set: optional replacement lgb.Dataset; when NULL, the stored
    #            training set is re-pushed only if its version changed
    #            since the last update.
    # fobj:      optional custom objective function(preds, dataset) returning
    #            list(grad = ..., hess = ...). When provided, the booster's
    #            objective is switched to "none" (once) and the custom
    #            gradient/hessian drive the update.
    update = function(train_set = NULL, fobj = NULL) {

      if (is.null(train_set)) {
        # Dataset was mutated since the last update: refresh it on the native side
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!.is_Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # FIX: `predictor` is a private field of lgb.Dataset, so the previous
        # `train_set$predictor` always evaluated to NULL and the check misfired
        # whenever an init predictor existed. Use the private accessor, as
        # initialize() and add_valid() already do.
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {
        # Once the objective has been reset to "none" (for custom objectives),
        # built-in updates are no longer possible
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost one iteration using the built-in objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        # Disable the built-in objective exactly once
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }
        # Evaluate the custom objective on current training-set predictions
        preds <- private$inner_predict(1L)
        gpair <- fobj(preds, private$train_set)

        # The custom objective must return both gradient and hessian
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should
            return a list with attributes (hess, grad)")
        }

        # Both vectors must match the prediction vector's length
        n_grad <- length(gpair$grad)
        n_hess <- length(gpair$hess)
        n_preds <- length(preds)
        if (n_grad != n_preds) {
          stop(sprintf("Expected custom objective function to return grad with length %d, got %d.", n_preds, n_grad))
        }
        if (n_hess != n_preds) {
          stop(sprintf("Expected custom objective function to return hess with length %d, got %d.", n_preds, n_hess))
        }

        # Boost one iteration from the custom gradient/hessian
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , n_preds
        )

      }

      # Invalidate all cached per-dataset predictions for the new iteration
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },
275

276
    # Discard the most recent boosting iteration from the model
    rollback_one_iter = function() {

      self$restore_handle()

      .Call(LGBM_BoosterRollbackOneIter_R, private$handle)

      # Every dataset's cached predictions are now stale
      n_flags <- length(private$is_predicted_cur_iter)
      for (idx in seq_len(n_flags)) {
        private$is_predicted_cur_iter[[idx]] <- FALSE
      }

      return(invisible(self))

    },
294

295
    # Number of boosting iterations currently in the model.
    # The native call writes the count into the integer passed to it.
    current_iter = function() {

      self$restore_handle()

      iteration_count <- 0L
      .Call(LGBM_BoosterGetCurrentIteration_R, private$handle, iteration_count)
      return(iteration_count)

    },
309

310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
    # Trees grown per boosting iteration (reported by the native booster;
    # presumably > 1 for multiclass objectives — confirm against the C API)
    num_trees_per_iter = function() {

      self$restore_handle()

      per_iter <- 1L
      .Call(LGBM_BoosterNumModelPerIteration_R, private$handle, per_iter)
      return(per_iter)

    },

    # Total tree count across all iterations
    num_trees = function() {

      self$restore_handle()

      total_trees <- 0L
      .Call(LGBM_BoosterNumberOfTotalModel_R, private$handle, total_trees)
      return(total_trees)

    },

    # Number of boosting rounds, derived from the two counts above
    num_iter = function() {
      return(self$num_trees() / self$num_trees_per_iter())
    },

350
    # Upper bound value of the model's output, as reported by the native library.
    # The native call writes the result into the double passed to it.
    upper_bound = function() {

      self$restore_handle()

      bound <- 0.0
      .Call(LGBM_BoosterGetUpperBoundValue_R, private$handle, bound)
      return(bound)

    },

    # Lower bound value of the model's output, as reported by the native library
    lower_bound = function() {

      self$restore_handle()

      bound <- 0.0
      .Call(LGBM_BoosterGetLowerBoundValue_R, private$handle, bound)
      return(bound)

    },

380
    # Evaluate metrics for `data`, registering it as a new validation set
    # first if the booster has not seen it before.
    eval = function(data, name, feval = NULL) {

      if (!.is_Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the dataset among those already known to the booster:
      # index 1 is the training set, validation sets start at index 2
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        matching <- which(vapply(
          private$valid_sets
          , function(valid_set) identical(data, valid_set)
          , FUN.VALUE = logical(1L)
        ))
        if (length(matching) > 0L) {
          data_idx <- matching[[1L]] + 1L
        }
      }

      # Unknown dataset: attach it as a fresh validation set
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },
432

433
    # Metrics on the training set (dataset index 1)
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Metrics on every registered validation set, concatenated in
    # registration order (validation set i has dataset index i + 1)
    eval_valid = function(feval = NULL) {

      if (length(private$valid_sets) <= 0L) {
        return(list())
      }

      per_set_results <- lapply(
        seq_along(private$valid_sets)
        , function(idx) {
          private$inner_eval(private$name_valid_sets[[idx]], idx + 1L, feval)
        }
      )

      return(do.call(c, per_set_results))

    },
457

458
    # Write the model as a text file to `filename`.
    # num_iteration = NULL means "use the best iteration stored on the object";
    # start_iteration is 1-based here and converted for the 0-based native API.
    save_model = function(
      filename
      , num_iteration = NULL
      , feature_importance_type = 0L
      , start_iteration = 1L
    ) {

      self$restore_handle()

      num_iteration <- if (is.null(num_iteration)) self$best_iter else num_iteration

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , path.expand(filename)
        , as.integer(start_iteration) - 1L  # convert to 0-based
      )

      return(invisible(self))
    },
485

486
487
488
489
490
491
    # Serialize the model in memory, returning a character string when
    # as_char = TRUE or the underlying raw bytes when as_char = FALSE.
    save_model_to_string = function(
      num_iteration = NULL
      , feature_importance_type = 0L
      , as_char = TRUE
      , start_iteration = 1L
    ) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      serialized <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , as.integer(start_iteration) - 1L  # convert to 0-based
      )

      if (as_char) {
        serialized <- rawToChar(serialized)
      }

      return(serialized)

    },
514

515
    # Dump the model to an in-memory string via the native dump routine
    dump_model = function(
      num_iteration = NULL, feature_importance_type = 0L, start_iteration = 1L
    ) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      dumped <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , as.integer(start_iteration) - 1L  # convert to 0-based
      )

      return(dumped)

    },
537

538
    # Predict on new data.
    #
    # data:            input features (matrix / sparse matrix / file path —
    #                  whatever Predictor$predict() accepts)
    # start_iteration: first iteration to use; NULL means 0L (from the start)
    # num_iteration:   iteration count limit; NULL means the stored best_iter
    # rawscore / predleaf / predcontrib: prediction type flags; values in
    #                  `params` (predict_raw_score / predict_leaf_index /
    #                  predict_contrib) take precedence over these arguments
    # header:          whether a text-file input has a header row
    # params:          extra prediction parameters forwarded to the Predictor
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       params = list()) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # possibly override keyword arguments with parameters
      #
      # NOTE: this length() check minimizes the latency introduced by these checks,
      #       for the common case where params is empty
      #
      # NOTE: doing this here instead of in Predictor$predict() to keep
      #       Predictor$predict() as fast as possible
      if (length(params) > 0L) {
        params <- .check_wrapper_param(
          main_param_name = "predict_raw_score"
          , params = params
          , alternative_kwarg_value = rawscore
        )
        params <- .check_wrapper_param(
          main_param_name = "predict_leaf_index"
          , params = params
          , alternative_kwarg_value = predleaf
        )
        params <- .check_wrapper_param(
          main_param_name = "predict_contrib"
          , params = params
          , alternative_kwarg_value = predcontrib
        )
        # After reconciliation, the params entries are authoritative
        rawscore <- params[["predict_raw_score"]]
        predleaf <- params[["predict_leaf_index"]]
        predcontrib <- params[["predict_contrib"]]
      }

      # Delegate to a Predictor sharing this booster's handle; the cached
      # fast-predict configuration (if any) is passed along
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
        , fast_predict_config = private$fast_predict_config
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
        )
      )
    },
605

606
607
    # Build a Predictor object that shares this booster's native handle
    to_predictor = function() {
      predictor <- Predictor$new(modelfile = private$handle)
      return(predictor)
    },
610

611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
    # Pre-initialize a fast single-row prediction configuration and cache it
    # on the object (read later via private$fast_predict_config in predict()).
    #
    # csr:             FALSE -> dense (matrix) single-row fast path,
    #                  TRUE  -> CSR (sparse) single-row fast path
    # start_iteration: NULL maps to 0L (start from the first iteration)
    # num_iteration:   NULL maps to -1L (use all iterations)
    # rawscore / predleaf / predcontrib: prediction type flags baked into
    #                  the fast configuration
    # params:          extra prediction parameters, serialized for the native call
    configure_fast_predict = function(csr = FALSE,
                                      start_iteration = NULL,
                                      num_iteration = NULL,
                                      rawscore = FALSE,
                                      predleaf = FALSE,
                                      predcontrib = FALSE,
                                      params = list()) {

      self$restore_handle()
      ncols <- .Call(LGBM_BoosterGetNumFeature_R, private$handle)

      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # Pick the native initializer matching the input layout
      if (!csr) {
        fun <- LGBM_BoosterPredictForMatSingleRowFastInit_R
      } else {
        fun <- LGBM_BoosterPredictForCSRSingleRowFastInit_R
      }

      fast_handle <- .Call(
        fun
        , private$handle
        , ncols
        , rawscore
        , predleaf
        , predcontrib
        , start_iteration
        , num_iteration
        , .params2str(params = params)
      )

      # Cache everything a Predictor needs to decide whether a given
      # prediction request matches this fast configuration
      private$fast_predict_config <- list(
        handle = fast_handle
        , csr = as.logical(csr)
        , ncols = ncols
        , start_iteration = start_iteration
        , num_iteration = num_iteration
        , rawscore = as.logical(rawscore)
        , predleaf = as.logical(predleaf)
        , predcontrib = as.logical(predcontrib)
        , params = params
      )

      return(invisible(NULL))
    },

662
663
    # Serialized model bytes kept on the object so the native handle can be
    # rebuilt (e.g. after deserialization); NULL when nothing is cached
    raw = NULL,

    # Cache the serialized model bytes (no-op when already cached)
    save_raw = function() {
      if (!is.null(self$raw)) {
        return(invisible(NULL))
      }
      self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      return(invisible(NULL))

    },

    # Discard the cached serialized model bytes
    drop_raw = function() {
      self$raw <- NULL
      return(invisible(NULL))
    },

    # TRUE when the native booster handle is no longer valid
    check_null_handle = function() {
      return(.is_null_handle(private$handle))
    },

    # Rebuild the native handle from the cached raw bytes when it has been
    # invalidated; raises the native null-handle error if nothing is cached.
    restore_handle = function() {
      if (!self$check_null_handle()) {
        return(invisible(NULL))
      }
      if (is.null(self$raw)) {
        .Call(LGBM_NullBoosterHandleError_R)
      }
      private$handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
      return(invisible(NULL))
    },

    # Expose the native handle for internal consumers
    get_handle = function() {
      return(private$handle)
    }
696

Guolin Ke's avatar
Guolin Ke committed
697
698
  ),
  private = list(
    # Native booster handle (external pointer); NULL after finalization
    handle = NULL,
    # The lgb.Dataset used for training
    train_set = NULL,
    # Display name for the training set in evaluation / buffer keys
    name_train_set = "training",
    # Registered validation lgb.Dataset objects and their display names,
    # in registration order
    valid_sets = list(),
    name_valid_sets = list(),
    # Cached prediction vectors, keyed by dataset display name
    predict_buffer = list(),
    # One flag per dataset: TRUE when its buffer holds current-iteration predictions
    is_predicted_cur_iter = list(),
    # Filled in place by LGBM_BoosterGetNumClasses_R during initialize()
    num_class = 1L,
    # Datasets known to the booster (training set + validation sets)
    num_dataset = 0L,
    # Predictor taken from the training dataset at construction
    # (merged into the booster via LGBM_BoosterMerge_R when present)
    init_predictor = NULL,
    # Cached metric names reported by the native booster
    eval_names = NULL,
    # For each cached metric, whether larger values are better
    higher_better_inner_eval = NULL,
    # TRUE once the objective was reset to "none" for custom-objective updates
    set_objective_to_none = FALSE,
    # Version stamp of the training set at the last update/reset
    train_set_version = 0L,
    # Cached fast single-row prediction configuration (see configure_fast_predict)
    fast_predict_config = list(),
714
715
    # Return (possibly cached) predictions for dataset `idx`:
    # 1 = training set, >= 2 = validation sets in registration order.
    inner_predict = function(idx) {

      # The prediction buffer is keyed by the dataset's display name
      data_name <- private$name_train_set

      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Check for unknown dataset (over the maximum provided range)
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Lazily allocate the buffer for this dataset
      if (is.null(private$predict_buffer[[data_name]])) {

        # Ask the native booster how many prediction values to expect;
        # the call writes the count into npred in place
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer unless it already holds current-iteration predictions
      if (!private$is_predicted_cur_iter[[idx]]) {

        # The native call fills the buffer in place
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },
759

760
    # Fetch (and cache) the metric names the native booster evaluates, plus
    # each metric's "higher is better" orientation. Returns the cached names.
    get_eval_info = function() {

      if (is.null(private$eval_names)) {

        names_from_booster <- .Call(LGBM_BoosterGetEvalNames_R, private$handle)

        if (length(names_from_booster) > 0L) {

          private$eval_names <- names_from_booster

          # Strip any "@k" suffix before looking up the orientation table:
          # e.g. "ndcg@1" is just the ndcg metric evaluated at the first
          # "query result" in learning-to-rank
          base_metric_names <- gsub("@.*", "", names_from_booster)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[base_metric_names]

        }

      }

      return(private$eval_names)

    },
786

787
788
789
790
791
792
793
794
795
796
797
798
799
800
    # Read back the parameters stored inside a loaded model (JSON produced by
    # the native library), converting interaction constraints to 1-based indexing.
    get_loaded_param = function(handle) {
      param_json <- .Call(LGBM_BoosterGetLoadedParam_R, handle)
      loaded_params <- jsonlite::fromJSON(param_json)
      if ("interaction_constraints" %in% names(loaded_params)) {
        # The model stores 0-based feature indices; R callers expect 1-based
        loaded_params[["interaction_constraints"]] <- lapply(
          loaded_params[["interaction_constraints"]]
          , function(constraint) constraint + 1L
        )
      }

      return(loaded_params)

    },

Guolin Ke's avatar
Guolin Ke committed
801
    # Compute all evaluation results for dataset `data_idx` (1 = training set,
    # >= 2 = validation sets): built-in metrics from the native booster first,
    # then the optional custom metric `feval`. Returns a list of result lists,
    # each with data_name / name / value / higher_better.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      # Ensure metric names / orientations are cached
      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Native call fills tmp_vals in place, one value per metric
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {

          # Package one metric result and append it to the return list
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Optionally evaluate a custom metric function
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set

        # For data_idx > 1, evaluate against the corresponding validation set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        # feval receives current predictions for this dataset plus the dataset itself
        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) ||  is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)"
          )
        }

        # Tag the custom result with the dataset name and append it
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }
  )
)

875
#' @name lgb_predict_shared_params
876
877
878
879
880
#' @param type Type of prediction to output. Allowed types are:\itemize{
#'             \item \code{"response"}: will output the predicted score according to the objective function being
#'                   optimized (depending on the link function that the objective uses), after applying any necessary
#'                   transformations - for example, for \code{objective="binary"}, it will output class probabilities.
#'             \item \code{"class"}: for classification objectives, will output the class with the highest predicted
881
882
883
#'                   probability. For other objectives, will output the same as "response". Note that \code{"class"} is
#'                   not a supported type for \link{lgb.configure_fast_predict} (see the documentation of that function
#'                   for more details).
884
885
886
887
888
889
890
#'             \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations'
#'                   results) from which the "response" number is produced for a given objective function - for example,
#'                   for \code{objective="binary"}, this corresponds to log-odds. For many objectives such as
#'                   "regression", since no transformation is applied, the output will be the same as for "response".
#'             \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls
#'                   in each tree in the model, outputted as integers, with one column per tree.
#'             \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an
#'                   intercept (each feature will produce one column).
#'             }
#'
#'             Note that, if using custom objectives, types "class" and "response" will not be available and will
#'             default towards using "raw" instead.
#'
#'             If the model was fit through function \link{lightgbm} and it was passed a factor as labels,
#'             passing the prediction type through \code{params} instead of through this argument might
#'             result in factor levels for classification objectives not being applied correctly to the
#'             resulting output.
#'
#'             \emph{New in version 4.0.0}
#'
#' @param start_iteration int or None, optional (default=None)
#'                        Start index of the iteration to predict.
#'                        If None or <= 0, starts from the first iteration.
#' @param num_iteration int or None, optional (default=None)
#'                      Limit number of iterations in the prediction.
#'                      If None, if the best iteration exists and start_iteration is None or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values. Where these conflict with the values of keyword arguments to this function,
#'               the values in \code{params} take precedence.
NULL

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#'
#'              \emph{New in version 4.0.0}
#'
#' @details If the model object has been configured for fast single-row predictions through
#'          \link{lgb.configure_fast_predict}, this function will use the prediction parameters
#'          that were configured for it - as such, extra prediction parameters should not be passed
#'          here, otherwise the configuration will be ignored and the slow route will be taken.
#' @inheritParams lgb_predict_shared_params
#' @param object Object of class \code{lgb.Booster}
#' @param newdata a \code{matrix} object, a \code{dgCMatrix}, a \code{dgRMatrix} object, a \code{dsparseVector} object,
#'                or a character representing a path to a text file (CSV, TSV, or LibSVM).
#'
#'                For sparse inputs, if predictions are only going to be made for a single row, it will be faster to
#'                use CSR format, in which case the data may be passed as either a single-row CSR matrix (class
#'                \code{dgRMatrix} from package \code{Matrix}) or as a sparse numeric vector (class
#'                \code{dsparseVector} from package \code{Matrix}).
#'
#'                If single-row predictions are going to be performed frequently, it is recommended to
#'                pre-configure the model object for fast single-row sparse predictions through function
#'                \link{lgb.configure_fast_predict}.
#'
#'                \emph{Changed from 'data', in version 4.0.0}
#'
#' @param header only used for prediction for text file. True if text file has header
#' @param ... ignored
#' @return For prediction types that are meant to always return one output per observation (e.g. when predicting
#'         \code{type="response"} or \code{type="raw"} on a binary classification or regression objective), will
#'         return a vector with one element per row in \code{newdata}.
#'
#'         For prediction types that are meant to return more than one output per observation (e.g. when predicting
#'         \code{type="response"} or \code{type="raw"} on a multi-class objective, or when predicting
#'         \code{type="leaf"}, regardless of objective), will return a matrix with one row per observation in
#'         \code{newdata} and one column per output.
#'
#'         For \code{type="leaf"} predictions, will return a matrix with one row per observation in \code{newdata}
#'         and one column per tree. Note that for multiclass objectives, LightGBM trains one tree per class at each
#'         boosting iteration. That means that, for example, for a multiclass model with 3 classes, the leaf
#'         predictions for the first class can be found in columns 1, 4, 7, 10, etc.
#'
#'         For \code{type="contrib"}, will return a matrix of SHAP values with one row per observation in
#'         \code{newdata} and columns corresponding to features. For regression, ranking, cross-entropy, and binary
#'         classification objectives, this matrix contains one column per feature plus a final column containing the
#'         Shapley base value. For multiclass objectives, this matrix will represent \code{num_classes} such matrices,
#'         in the order "feature contributions for first class, feature contributions for second class, feature
#'         contributions for third class, etc.".
#'
#'         If the model was fit through function \link{lightgbm} and it was passed a factor as labels, predictions
#'         returned from this function will retain the factor levels (either as values for \code{type="class"}, or
#'         as column names for \code{type="response"} and \code{type="raw"} for multi-class objectives). Note that
#'         passing the requested prediction type under \code{params} instead of through \code{type} might result in
#'         the factor levels not being present in the output.
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'    )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                newdata,
                                type = "response",
                                start_iteration = NULL,
                                num_iteration = NULL,
                                header = FALSE,
                                params = list(),
                                ...) {

  if (!.is_Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Arguments passed through '...' are not used. Catch arguments that were
  # removed in lightgbm 4.0 explicitly (with a pointer to the modern
  # equivalent), and warn about anything else so typos do not pass silently.
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    additional_params_names <- names(additional_params)
    if ("reshape" %in% additional_params_names) {
      stop("'reshape' argument is no longer supported.")
    }

    old_args_for_type <- list(
      "rawscore" = "raw"
      , "predleaf" = "leaf"
      , "predcontrib" = "contrib"
    )
    for (arg in names(old_args_for_type)) {
      if (arg %in% additional_params_names) {
        stop(sprintf("Argument '%s' is no longer supported. Use type='%s' instead."
                     , arg
                     , old_args_for_type[[arg]]))
      }
    }

    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , toString(additional_params_names)
      , ". These are ignored. Use argument 'params' instead."
    ))
  }

  # Types "class" and "response" require applying the objective's link
  # function, which is unavailable for custom objectives ("none"):
  # fall back to "raw" with a warning.
  if (!is.null(object$params$objective) && object$params$objective == "none" && type %in% c("class", "response")) {
    warning("Prediction types 'class' and 'response' are not supported for custom objectives.")
    type <- "raw"
  }

  # Translate the user-facing 'type' into the booster's internal flags
  rawscore <- FALSE
  predleaf <- FALSE
  predcontrib <- FALSE
  if (type == "raw") {
    rawscore <- TRUE
  } else if (type == "leaf") {
    predleaf <- TRUE
  } else if (type == "contrib") {
    predcontrib <- TRUE
  }

  pred <- object$predict(
    data = newdata
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , params = params
  )

  # For "class", collapse probabilities into predicted class indices:
  # threshold at 0.5 for binary objectives, 0-based argmax for multiclass
  if (type == "class") {
    if (object$params$objective %in% .BINARY_OBJECTIVES()) {
      pred <- as.integer(pred >= 0.5)
    } else if (object$params$objective %in% .MULTICLASS_OBJECTIVES()) {
      pred <- max.col(pred) - 1L
    }
  }

  # If the model was fit through lightgbm() with a data processor (e.g. when
  # a factor was passed as labels), let it post-process the predictions
  if (!is.null(object$data_processor)) {
    pred <- object$data_processor$process_predictions(
      pred = pred
      , type = type
    )
  }

  return(pred)
}

1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
#' @title Configure Fast Single-Row Predictions
#' @description Pre-configures a LightGBM model object to produce fast single-row predictions
#'              for a given input data type, prediction type, and parameters.
#' @details Calling this function multiple times with different parameters might not override
#'          the previous configuration and might trigger undefined behavior.
#'
#'          Any saved configuration for fast predictions might be lost after making a single-row
#'          prediction of a different type than what was configured (except for types "response" and
#'          "class", which can be switched between each other at any time without losing the configuration).
#'
#'          In some situations, setting a fast prediction configuration for one type of prediction
#'          might cause the prediction function to keep using that configuration for single-row
#'          predictions even if the requested type of prediction is different from what was configured.
#'
#'          Note that this function will not accept argument \code{type="class"} - for such cases, one
#'          can pass \code{type="response"} to this function and then \code{type="class"} to the
#'          \code{predict} function - the fast configuration will not be lost or altered if the switch
#'          is between "response" and "class".
#'
#'          The configuration does not survive de-serializations, so it has to be generated
#'          anew in every R process that is going to use it (e.g. if loading a model object
#'          through \code{readRDS}, whatever configuration was there previously will be lost).
#'
#'          Requesting a different prediction type or passing parameters to \link{predict.lgb.Booster}
#'          will cause it to ignore the fast-predict configuration and take the slow route instead
#'          (but be aware that an existing configuration might not always be overridden by supplying
#'          different parameters or prediction type, so make sure to check that the output is what
#'          was expected when a prediction is to be made on a single row for something different than
#'          what is configured).
#'
#'          Note that, if configuring a non-default prediction type (such as leaf indices),
#'          then that type must also be passed in the call to \link{predict.lgb.Booster} in
#'          order for it to use the configuration. This also applies for \code{start_iteration}
#'          and \code{num_iteration}, but \bold{the \code{params} list must be empty} in the call to \code{predict}.
#'
#'          Predictions about feature contributions do not allow a fast route for CSR inputs,
#'          and as such, this function will produce an error if passing \code{csr=TRUE} and
#'          \code{type = "contrib"} together.
#' @inheritParams lgb_predict_shared_params
#' @param model LightGBM model object (class \code{lgb.Booster}).
#'
#'              \bold{The object will be modified in-place}.
#' @param csr Whether the prediction function is going to be called on sparse CSR inputs.
#'            If \code{FALSE}, will be assumed that predictions are going to be called on single-row
#'            regular R matrices.
#' @return The same \code{model} that was passed as input, invisibly, with the desired
#'         configuration stored inside it and available to be used in future calls to
#'         \link{predict.lgb.Booster}.
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' library(lightgbm)
#' data(mtcars)
#' X <- as.matrix(mtcars[, -1L])
#' y <- mtcars[, 1L]
#' dtrain <- lgb.Dataset(X, label = y, params = list(max_bin = 5L))
#' params <- list(
#'   min_data_in_leaf = 2L
#'   , num_threads = 2L
#' )
#' model <- lgb.train(
#'   params = params
#'  , data = dtrain
#'  , obj = "regression"
#'  , nrounds = 5L
#'  , verbose = -1L
#' )
#' lgb.configure_fast_predict(model)
#'
#' x_single <- X[11L, , drop = FALSE]
#' predict(model, x_single)
#'
#' # Will not use it if the prediction to be made
#' # is different from what was configured
#' predict(model, x_single, type = "leaf")
#' }
#' @export
lgb.configure_fast_predict <- function(model,
                                       csr = FALSE,
                                       start_iteration = NULL,
                                       num_iteration = NULL,
                                       type = "response",
                                       params = list()) {

  if (!.is_Booster(x = model)) {
    stop("lgb.configure_fast_predict: model should be an ", sQuote("lgb.Booster"))
  }

  # "class" cannot be pre-configured; users should configure "response" instead
  if (type == "class") {
    stop("type='class' is not supported for 'lgb.configure_fast_predict'. Use 'response' instead.")
  }

  # Translate the user-facing prediction type into the booster's internal flags
  rawscore <- type == "raw"
  predleaf <- type == "leaf"
  predcontrib <- type == "contrib"

  # Feature contributions have no fast path for CSR inputs
  if (csr && predcontrib) {
    stop("'lgb.configure_fast_predict' does not support feature contributions for CSR data.")
  }

  model$configure_fast_predict(
    csr = csr
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , params = params
  )
  return(invisible(model))
}

1209
1210
1211
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#'
#'              \emph{New in version 4.0.0}
#'
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input \code{x}, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  booster_env <- x$.__enclos_env__$private
  handle <- booster_env$handle
  handle_is_null <- .is_null_handle(handle)

  # Header line: include the tree count only when the handle is usable
  if (handle_is_null) {
    cat("LightGBM Model\n")
  } else {
    ntrees <- x$current_iter()
    if (ntrees == 1L) {
      cat("LightGBM Model (1 tree)\n")
    } else {
      cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
    }
  }

  # Objective line, or a notice when the underlying handle is gone
  if (handle_is_null) {
    cat("(Booster handle is invalid)\n")
  } else {
    obj <- x$params$objective
    if (obj == "none") {
      obj <- "custom"
    }
    n_classes <- booster_env$num_class
    if (n_classes == 1L) {
      cat(sprintf("Objective: %s\n", obj))
    } else {
      cat(sprintf("Objective: %s (%d classes)\n"
          , obj
          , n_classes))
    }
  }

  # Number of features the model was fitted to
  if (!handle_is_null) {
    n_features <- .Call(LGBM_BoosterGetNumFeature_R, handle)
    cat(sprintf("Fitted to dataset with %d columns\n", n_features))
  }
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#'
#'              \emph{New in version 4.0.0}
#'
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input \code{object}, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # Delegate to the print method; it already returns its input invisibly
  return(invisible(print(object)))
}

1275
1276
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load LightGBM takes in either a file path or model string.
#'              If both are provided, Load will default to loading from file
#' @param filename path of model file
#' @param model_str a str containing the model (as a \code{character} or \code{raw} vector)
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  # File input takes precedence when both are supplied
  if (!is.null(filename)) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    model_path <- path.expand(filename)
    if (!file.exists(model_path)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", model_path))
    }
    return(invisible(Booster$new(modelfile = model_path)))
  }

  if (!is.null(model_str)) {
    if (!is.raw(model_str) && !is.character(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

1342
1343
1344
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename Saved filename
#' @param num_iteration Number of iterations to save, NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to save.
#'        For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#'        means "save the fifth, sixth, and seventh tree"
#'
#'        \emph{New in version 4.4.0}
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(
    booster, filename, num_iteration = NULL, start_iteration = 1L
  ) {

  if (!.is_Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # Require a single character string as the destination path
  filename_is_valid <- is.character(filename) && length(filename) == 1L
  if (!filename_is_valid) {
    stop("lgb.save: filename should be a string")
  }

  # Persist the model to disk and return the booster invisibly
  out_path <- path.expand(filename)
  saved_booster <- booster$save_model(
    filename = out_path
    , num_iteration = num_iteration
    , start_iteration = start_iteration
  )
  return(invisible(saved_booster))
}

1409
1410
1411
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration Number of iterations to be dumped. NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to dump.
#'        For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#'        means "dump the fifth, sixth, and seventh tree"
#'
#'        \emph{New in version 4.4.0}
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL, start_iteration = 1L) {

  if (!.is_Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the booster's own JSON dump routine for the requested rounds
  dumped_model <- booster$dump_model(
    num_iteration = num_iteration
    , start_iteration = start_iteration
  )
  return(dumped_model)
}

1466
1467
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' \dontshow{setLGBMthreads(2L)}
#' \dontshow{data.table::setDTthreads(1L)}
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , num_threads = 2L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!.is_Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  valid_datasets <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% valid_datasets)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , toString(valid_datasets)
      , "]"
    ))
  }

  # Make sure the requested metric was recorded for that dataset
  valid_metrics <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% valid_metrics)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , toString(valid_metrics)
      , "]"
    ))
  }

  # Pull either the metric values or their errors, depending on is_err
  metric_log <- booster$record_evals[[data_name]][[eval_name]]
  if (is_err) {
    result <- metric_log[[.EVAL_ERR_KEY()]]
  } else {
    result <- metric_log[[.EVAL_KEY()]]
  }

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Map user-facing iteration numbers onto positions in the recorded log
  offset <- booster$record_evals$start_iter - 1.0
  positions <- as.integer(iters) - offset
  return(as.numeric(result[positions]))
}