lgb.Booster.R 43.8 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
#' @importFrom R6 R6Class
2
#' @importFrom utils modifyList
James Lamb's avatar
James Lamb committed
3
Booster <- R6::R6Class(
4
  classname = "lgb.Booster",
5
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
6
  public = list(
7

8
    best_iter = -1L,
9
    best_score = NA_real_,
10
    params = list(),
Guolin Ke's avatar
Guolin Ke committed
11
    record_evals = list(),
12

13
14
    # Finalize will free up the handles
    finalize = function() {
      # Hand the stored handle to the native free routine, then clear the
      # R-side reference to it. Returns NULL invisibly.
      .Call(LGBM_BoosterFree_R, private$handle)
      private$handle <- NULL
      invisible(NULL)
    },
22

23
24
    # Initialize will create a starter booster
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      # Creates the native booster from exactly one source, tried in this
      # order: a training lgb.Dataset, a model file path, or a serialized
      # model (character or raw). Stops when none of the three is supplied.
      booster_handle <- NULL

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }

        dataset_private <- train_set$.__enclos_env__$private
        dataset_handle <- dataset_private$get_handle()

        # Entries returned by the Dataset replace same-named entries of `params`
        params <- utils::modifyList(params, train_set$get_params())
        booster_handle <- .Call(
          LGBM_BoosterCreate_R
          , dataset_handle
          , lgb.params2str(params = params)
        )

        # Remember the training data and the Dataset version seen at creation
        private$train_set <- train_set
        private$train_set_version <- dataset_private$version
        private$num_dataset <- 1L
        private$init_predictor <- dataset_private$predictor

        # When the Dataset carries a predictor, merge its handle into the new booster
        if (!is.null(private$init_predictor)) {
          .Call(
            LGBM_BoosterMerge_R
            , booster_handle
            , private$init_predictor$.__enclos_env__$private$handle
          )
        }

        # One flag per registered dataset; this one belongs to the training set
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        booster_handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , path.expand(modelfile)
        )

      } else if (!is.null(model_str)) {

        if (!(is.raw(model_str) || is.character(model_str))) {
          stop("lgb.Booster: Can only use a character/raw vector as model_str")
        }

        booster_handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )

      } else {

        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      # Tag the handle, store it, then ask the native side for the class count.
      # private$num_class is passed to the native call directly
      # (presumably filled in place - TODO confirm against the C side).
      class(booster_handle) <- "lgb.Booster.handle"
      private$handle <- booster_handle
      private$num_class <- 1L
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
      )

      self$params <- params

      invisible(NULL)

    },
118

119
    # Set training data name
Guolin Ke's avatar
Guolin Ke committed
120
    set_train_data_name = function(name) {
      # Stores `name` as the label of the training set (used as the data name
      # when the training set is evaluated). Returns the booster invisibly.
      private$name_train_set <- name
      invisible(self)
    },
127

128
    # Add validation data
Guolin Ke's avatar
Guolin Ke committed
129
    add_valid = function(data, name) {

      # Registers `data` as an additional validation set under `name` and
      # returns the booster invisibly. All inputs are validated before the
      # native call is made.
      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      valid_private <- data$.__enclos_env__$private

      # The validation data must carry the same predictor as the booster
      same_predictor <- identical(valid_private$predictor, private$init_predictor)
      if (!same_predictor) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , valid_private$get_handle()
      )

      # Bookkeeping: dataset, its name, the dataset count and a fresh
      # "predictions are current" flag for the new dataset
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      invisible(self)

    },
161

162
    reset_parameter = function(params) {

      # Applies `params` to the native booster and records them in self$params.
      # When self$params is a list, the new values are layered on top of it.
      # Returns the booster invisibly.
      merged <- params
      if (methods::is(self$params, "list")) {
        merged <- utils::modifyList(self$params, params)
      }

      serialized <- lgb.params2str(params = merged)

      self$restore_handle()

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , serialized
      )
      self$params <- merged

      invisible(self)

    },
182

183
    # Perform boosting update iteration
Guolin Ke's avatar
Guolin Ke committed
184
    update = function(train_set = NULL, fobj = NULL) {

      # Runs one boosting iteration and returns the booster invisibly.
      #
      # train_set: optional lgb.Dataset to switch training to before updating.
      #            When NULL, the stored training set is re-sent to the native
      #            booster only if its version differs from the recorded one.
      # fobj:      optional custom objective, called as fobj(preds, train_set).
      #            It must return a list with `grad` and `hess`, each as long
      #            as the predictions.

      if (is.null(train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # The Dataset's predictor lives in its private environment (that is
        # where initialize() and add_valid() read it), so compare that field
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      # Check if objective is empty
      if (is.null(fobj)) {
        # Once the objective was reset to "none" for a custom objective,
        # updating without fobj is refused
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost iteration from known objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        # First use of a custom objective: reset the objective parameter to "none"
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }
        # Perform objective calculation on the training-set predictions
        preds <- private$inner_predict(1L)
        gpair <- fobj(preds, private$train_set)

        # Check for gradient and hessian as list
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective should "
            , "return a list with attributes (hess, grad)"
          )
        }

        # Check grad and hess have the right shape
        n_grad <- length(gpair$grad)
        n_hess <- length(gpair$hess)
        n_preds <- length(preds)
        if (n_grad != n_preds) {
          stop(sprintf("Expected custom objective function to return grad with length %d, got %d.", n_preds, n_grad))
        }
        if (n_hess != n_preds) {
          stop(sprintf("Expected custom objective function to return hess with length %d, got %d.", n_preds, n_hess))
        }

        # Return custom boosting gradient/hessian
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , n_preds
        )

      }

      # The model changed, so cached predictions of every dataset are stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },
273

274
    # Return one iteration behind
Guolin Ke's avatar
Guolin Ke committed
275
    rollback_one_iter = function() {

      # Asks the native booster to roll back one iteration, then marks the
      # cached predictions of every dataset as stale. Returns the booster
      # invisibly.
      self$restore_handle()

      .Call(LGBM_BoosterRollbackOneIter_R, private$handle)

      private$is_predicted_cur_iter <- lapply(
        X = private$is_predicted_cur_iter
        , FUN = function(flag) FALSE
      )

      invisible(self)

    },
292

293
    # Get current iteration
Guolin Ke's avatar
Guolin Ke committed
294
    current_iter = function() {

      # Returns the value of `iteration` after it has been passed to the native
      # "get current iteration" routine (presumably filled in place by that
      # call - TODO confirm against the C side).
      self$restore_handle()

      iteration <- 0L
      .Call(LGBM_BoosterGetCurrentIteration_R, private$handle, iteration)
      iteration

    },
307

308
    # Get upper bound
309
    upper_bound = function() {

      # Returns the value of `bound` after it has been passed to the native
      # "upper bound value" routine (presumably filled in place by that call -
      # TODO confirm against the C side).
      self$restore_handle()

      bound <- 0.0
      .Call(LGBM_BoosterGetUpperBoundValue_R, private$handle, bound)
      bound

    },

    # Get lower bound
324
    lower_bound = function() {

      # Returns the value of `bound` after it has been passed to the native
      # "lower bound value" routine (presumably filled in place by that call -
      # TODO confirm against the C side).
      self$restore_handle()

      bound <- 0.0
      .Call(LGBM_BoosterGetLowerBoundValue_R, private$handle, bound)
      bound

    },

338
    # Evaluate data on metrics
Guolin Ke's avatar
Guolin Ke committed
339
    eval = function(data, name, feval = NULL) {

      # Evaluates `data` and returns the list produced by inner_eval().
      # The dataset index is 1 for the training set and (position + 1) for an
      # already registered validation set; an unknown dataset is first
      # registered through add_valid() under `name`.
      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        # First registered validation set identical to `data`, NA if none
        hit <- Position(
          f = function(candidate) identical(data, candidate)
          , x = private$valid_sets
        )
        if (is.na(hit)) {
          self$add_valid(data, name)
          data_idx <- private$num_dataset
        } else {
          data_idx <- hit + 1L
        }
      }

      private$inner_eval(
        data_name = name
        , data_idx = data_idx
        , feval = feval
      )

    },
390

391
    # Evaluation training data
Guolin Ke's avatar
Guolin Ke committed
392
    eval_train = function(feval = NULL) {
      # Evaluates the training set (dataset index 1) under its stored name
      private$inner_eval(
        data_name = private$name_train_set
        , data_idx = 1L
        , feval = feval
      )
    },
395

396
    # Evaluation validation data
Guolin Ke's avatar
Guolin Ke committed
397
    eval_valid = function(feval = NULL) {

      # Evaluates every registered validation set in registration order and
      # returns all results concatenated into one list (empty when no
      # validation set is registered).
      per_set <- lapply(
        X = seq_along(private$valid_sets)
        , FUN = function(i) {
          private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        }
      )

      # Leading empty list keeps the result a list even with nothing to append
      do.call(what = c, args = c(list(list()), per_set))

    },
415

416
    # Save model
417
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      # Writes the model to `filename` (after path expansion) through the
      # native save routine. A NULL `num_iteration` falls back to
      # self$best_iter. Returns the booster invisibly.
      self$restore_handle()

      iteration_limit <- if (is.null(num_iteration)) self$best_iter else num_iteration
      target_path <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(iteration_limit)
        , as.integer(feature_importance_type)
        , target_path
      )

      invisible(self)
    },
437

438
439
440
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

      # Returns the serialized model from the native routine: converted with
      # rawToChar() when `as_char` is TRUE, otherwise as returned by the native
      # call. A NULL `num_iteration` falls back to self$best_iter.
      self$restore_handle()

      iteration_limit <- if (is.null(num_iteration)) self$best_iter else num_iteration

      serialized <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(iteration_limit)
        , as.integer(feature_importance_type)
      )

      if (as_char) {
        return(rawToChar(serialized))
      }

      serialized

    },
460

461
    # Dump model in memory
462
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      # Returns whatever the native dump routine produces for this booster.
      # A NULL `num_iteration` falls back to self$best_iter.
      self$restore_handle()

      iteration_limit <- if (is.null(num_iteration)) self$best_iter else num_iteration

      dumped <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(iteration_limit)
        , as.integer(feature_importance_type)
      )

      dumped

    },
480

481
    # Predict on new data
Guolin Ke's avatar
Guolin Ke committed
482
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       params = list()) {

      # Predicts on `data` through a Predictor built around this booster's
      # handle and returns that Predictor's result.
      self$restore_handle()

      # NULL iteration arguments fall back to the best iteration / iteration 0
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # Entries in `params` may override the rawscore / predleaf / predcontrib
      # keyword arguments. The whole step is skipped for empty `params` (the
      # common case) and is done here rather than in Predictor$predict() to
      # keep that method as fast as possible.
      if (length(params) > 0L) {
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_raw_score"
          , params = params
          , alternative_kwarg_value = rawscore
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_leaf_index"
          , params = params
          , alternative_kwarg_value = predleaf
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_contrib"
          , params = params
          , alternative_kwarg_value = predcontrib
        )
        rawscore <- params[["predict_raw_score"]]
        predleaf <- params[["predict_leaf_index"]]
        predcontrib <- params[["predict_contrib"]]
      }

      scorer <- Predictor$new(
        modelfile = private$handle
        , params = params
        , fast_predict_config = private$fast_predict_config
      )

      scorer$predict(
        data = data
        , start_iteration = start_iteration
        , num_iteration = num_iteration
        , rawscore = rawscore
        , predleaf = predleaf
        , predcontrib = predcontrib
        , header = header
      )

    },
548

549
550
    # Transform into predictor
    to_predictor = function() {
      # Wraps the booster's current handle in a new Predictor object
      wrapped <- Predictor$new(modelfile = private$handle)
      wrapped
    },
553

554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
    configure_fast_predict = function(csr = FALSE,
                                      start_iteration = NULL,
                                      num_iteration = NULL,
                                      rawscore = FALSE,
                                      predleaf = FALSE,
                                      predcontrib = FALSE,
                                      params = list()) {

      # Calls the native single-row "fast init" routine (CSR or dense-matrix
      # variant, chosen by `csr`) and stores the returned handle together with
      # the settings used in private$fast_predict_config. Returns NULL
      # invisibly.
      self$restore_handle()
      n_features <- .Call(LGBM_BoosterGetNumFeature_R, private$handle)

      # NULL iteration arguments fall back to -1 and 0 respectively
      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      init_routine <- if (!csr) {
        LGBM_BoosterPredictForMatSingleRowFastInit_R
      } else {
        LGBM_BoosterPredictForCSRSingleRowFastInit_R
      }

      configured_handle <- .Call(
        init_routine
        , private$handle
        , n_features
        , rawscore
        , predleaf
        , predcontrib
        , start_iteration
        , num_iteration
        , lgb.params2str(params = params)
      )

      private$fast_predict_config <- list(
        handle = configured_handle
        , csr = as.logical(csr)
        , ncols = n_features
        , start_iteration = start_iteration
        , num_iteration = num_iteration
        , rawscore = as.logical(rawscore)
        , predleaf = as.logical(predleaf)
        , predcontrib = as.logical(predcontrib)
        , params = params
      )

      invisible(NULL)
    },

605
606
    # Serialized copy of the model, used for serialization. NULL until
    # save_raw() stores the result of save_model_to_string(as_char = FALSE);
    # restore_handle() reads it to rebuild a null handle, and drop_raw()
    # resets it to NULL.
    raw = NULL,
607

608
609
610
611
612
613
    # Store serialized raw bytes in model object
    save_raw = function() {
      # Fills self$raw with the non-character serialized model, but only when
      # nothing is stored yet; an existing value is left untouched.
      # Returns NULL invisibly.
      nothing_stored <- is.null(self$raw)
      if (nothing_stored) {
        self$raw <- self$save_model_to_string(num_iteration = NULL, as_char = FALSE)
      }
      invisible(NULL)

    },
616

617
618
    drop_raw = function() {
      # Discards the serialized copy kept in self$raw. Returns NULL invisibly.
      self$raw <- NULL
      invisible(NULL)
    },
621

622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
    check_null_handle = function() {
      # Reports what lgb.is.null.handle() says about the stored handle
      handle_is_null <- lgb.is.null.handle(private$handle)
      handle_is_null
    },

    restore_handle = function() {
      # Rebuilds the native handle from the serialized copy in self$raw when
      # the current handle is null; does nothing otherwise.
      # Returns NULL invisibly.
      if (self$check_null_handle()) {
        if (is.null(self$raw)) {
          # Nothing to rebuild from; presumably this native routine raises the
          # error reported to the user - TODO confirm against the C side
          .Call(LGBM_NullBoosterHandleError_R)
        }
        restored <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
        # Tag the handle the same way initialize() does, so that both code
        # paths store a handle of class "lgb.Booster.handle"
        class(restored) <- "lgb.Booster.handle"
        private$handle <- restored
      }
      return(invisible(NULL))
    },

    get_handle = function() {
      # Exposes the handle stored in the private `handle` field
      private$handle
    }
639

Guolin Ke's avatar
Guolin Ke committed
640
641
  ),
  private = list(
642
643
644
645
646
647
648
    # Native booster handle: set by initialize() and restore_handle(),
    # reset to NULL by finalize()
    handle = NULL,
    # lgb.Dataset used for training: set by initialize() and update()
    train_set = NULL,
    # Data name reported for the training set; changed through
    # set_train_data_name()
    name_train_set = "training",
    # Validation datasets and their names, both appended to by add_valid()
    valid_sets = list(),
    name_valid_sets = list(),
    # Prediction vectors keyed by data name, allocated in inner_predict()
    predict_buffer = list(),
    # One flag per registered dataset: set to TRUE by inner_predict() and
    # reset to FALSE by update() and rollback_one_iter()
    is_predicted_cur_iter = list(),
649
650
    num_class = 1L,
    num_dataset = 0L,
651
652
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
653
    higher_better_inner_eval = NULL,
654
    set_objective_to_none = FALSE,
655
    train_set_version = 0L,
656
    fast_predict_config = list(),
657
658
    # Predict data
    inner_predict = function(idx) {

      # Returns the predictions for the dataset at position `idx`
      # (1 = training set, i + 1 = i-th validation set). Predictions are kept
      # in a per-dataset buffer and only refreshed when the dataset's
      # "predicted for the current iteration" flag is FALSE.

      # Validate the index before it is used to look anything up; otherwise an
      # out-of-range value fails on the name lookup below with a generic
      # subscript error instead of this message
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Store data name
      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Check for prediction buffer
      if (is.null(private$predict_buffer[[data_name]])) {

        # Size the buffer from the native prediction count; `npred` is passed
        # to the native call (presumably filled in place - TODO confirm)
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Check if current iteration was already predicted
      if (!private$is_predicted_cur_iter[[idx]]) {

        # The buffer itself is handed to the native call, which uses a
        # zero-based dataset index
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },
702

703
    # Get evaluation information
Guolin Ke's avatar
Guolin Ke committed
704
    get_eval_info = function() {

      # Returns the metric names of the booster. They are fetched from the
      # native side on first use and then kept in private$eval_names, together
      # with the matching entries of .METRICS_HIGHER_BETTER().
      if (!is.null(private$eval_names)) {
        return(private$eval_names)
      }

      fetched_names <- .Call(LGBM_BoosterGetEvalNames_R, private$handle)

      if (length(fetched_names) > 0L) {

        private$eval_names <- fetched_names

        # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
        # ndcg metric evaluated at the first "query result" in learning-to-rank
        base_names <- gsub("@.*", "", fetched_names)
        private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[base_names]

      }

      private$eval_names

    },
729

Guolin Ke's avatar
Guolin Ke committed
730
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Evaluates the dataset at position `data_idx` and returns a list with
      # one entry per built-in metric, followed by one entry for the custom
      # function `feval` when it is given. Every entry carries data_name,
      # name, value and higher_better.
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      private$get_eval_info()

      ret <- list()

      n_metrics <- length(private$eval_names)
      if (n_metrics > 0L) {

        # Buffer handed to the native call, which uses a zero-based dataset
        # index (presumably fills the buffer in place - TODO confirm)
        metric_values <- numeric(n_metrics)
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , metric_values
        )

        built_in <- lapply(
          X = seq_along(private$eval_names)
          , FUN = function(i) {
            entry <- list()
            entry$data_name <- data_name
            entry$name <- private$eval_names[i]
            entry$value <- metric_values[i]
            entry$higher_better <- private$higher_better_inner_eval[i]
            entry
          }
        )
        ret <- c(ret, built_in)

      }

      if (is.null(feval)) {
        return(ret)
      }

      if (!is.function(feval)) {
        stop("lgb.Booster.eval: feval should be a function")
      }

      # Dataset handed to the custom function: the training set for index 1,
      # otherwise the matching validation set
      eval_data <- private$train_set
      if (data_idx > 1L) {
        eval_data <- private$valid_sets[[data_idx - 1L]]
      }

      custom <- feval(private$inner_predict(data_idx), eval_data)

      incomplete <- is.null(custom$name) || is.null(custom$value) || is.null(custom$higher_better)
      if (incomplete) {
        stop(
          "lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)"
        )
      }

      custom$data_name <- data_name
      c(ret, list(custom))

    }
800

Guolin Ke's avatar
Guolin Ke committed
801
802
803
  )
)

804
#' @name lgb_predict_shared_params
805
806
807
808
809
#' @param type Type of prediction to output. Allowed types are:\itemize{
#'             \item \code{"response"}: will output the predicted score according to the objective function being
#'                   optimized (depending on the link function that the objective uses), after applying any necessary
#'                   transformations - for example, for \code{objective="binary"}, it will output class probabilities.
#'             \item \code{"class"}: for classification objectives, will output the class with the highest predicted
810
811
812
#'                   probability. For other objectives, will output the same as "response". Note that \code{"class"} is
#'                   not a supported type for \link{lgb.configure_fast_predict} (see the documentation of that function
#'                   for more details).
813
814
815
816
817
818
819
#'             \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations'
#'                   results) from which the "response" number is produced for a given objective function - for example,
#'                   for \code{objective="binary"}, this corresponds to log-odds. For many objectives such as
#'                   "regression", since no transformation is applied, the output will be the same as for "response".
#'             \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observation falls
#'                   in each tree in the model, outputted as integers, with one column per tree.
#'             \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an
820
#'                   intercept (each feature will produce one column).
821
822
823
824
#'             }
#'
#'             Note that, if using custom objectives, types "class" and "response" will not be available and will
#'             default towards using "raw" instead.
825
826
827
828
829
830
831
832
#' @param start_iteration int or \code{NULL}, optional (default=\code{NULL})
#'                        Start index of the iteration to predict.
#'                        If \code{NULL} or <= 0, starts from the first iteration.
#' @param num_iteration int or \code{NULL}, optional (default=\code{NULL})
#'                      Limit number of iterations in the prediction.
#'                      If \code{NULL}, if the best iteration exists and start_iteration is \code{NULL} or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
833
834
835
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
836
837
#'               valid values. Where these conflict with the values of keyword arguments to this function,
#'               the values in \code{params} take precedence.
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
NULL

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @details If the model object has been configured for fast single-row predictions through
#'          \link{lgb.configure_fast_predict}, this function will use the prediction parameters
#'          that were configured for it - as such, extra prediction parameters should not be passed
#'          here, otherwise the configuration will be ignored and the slow route will be taken.
#' @inheritParams lgb_predict_shared_params
#' @param object Object of class \code{lgb.Booster}
#' @param newdata a \code{matrix} object, a \code{dgCMatrix}, a \code{dgRMatrix} object, a \code{dsparseVector} object,
#'                or a character representing a path to a text file (CSV, TSV, or LibSVM).
#'
#'                For sparse inputs, if predictions are only going to be made for a single row, it will be faster to
#'                use CSR format, in which case the data may be passed as either a single-row CSR matrix (class
#'                \code{dgRMatrix} from package \code{Matrix}) or as a sparse numeric vector (class
#'                \code{dsparseVector} from package \code{Matrix}).
#'
#'                If single-row predictions are going to be performed frequently, it is recommended to
#'                pre-configure the model object for fast single-row sparse predictions through function
#'                \link{lgb.configure_fast_predict}.
#' @param header only used when predicting from a text file. \code{TRUE} if the text file has a header
#' @param ... ignored
#' @return For prediction types that are meant to always return one output per observation (e.g. when predicting
#'         \code{type="response"} or \code{type="raw"} on a binary classification or regression objective), will
#'         return a vector with one element per row in \code{newdata}.
#'
#'         For prediction types that are meant to return more than one output per observation (e.g. when predicting
#'         \code{type="response"} or \code{type="raw"} on a multi-class objective, or when predicting
#'         \code{type="leaf"}, regardless of objective), will return a matrix with one row per observation in
#'         \code{newdata} and one column per output.
#'
#'         For \code{type="leaf"} predictions, will return a matrix with one row per observation in \code{newdata}
#'         and one column per tree. Note that for multiclass objectives, LightGBM trains one tree per class at each
#'         boosting iteration. That means that, for example, for a multiclass model with 3 classes, the leaf
#'         predictions for the first class can be found in columns 1, 4, 7, 10, etc.
#'
#'         For \code{type="contrib"}, will return a matrix of SHAP values with one row per observation in
#'         \code{newdata} and columns corresponding to features. For regression, ranking, cross-entropy, and binary
#'         classification objectives, this matrix contains one column per feature plus a final column containing the
#'         Shapley base value. For multiclass objectives, this matrix will represent \code{num_classes} such matrices,
#'         in the order "feature contributions for first class, feature contributions for second class, feature
#'         contributions for third class, etc.".
#'
Guolin Ke's avatar
Guolin Ke committed
883
#' @examples
884
#' \donttest{
885
886
887
888
889
890
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
891
892
893
894
895
896
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
897
#' valids <- list(test = dtest)
898
899
900
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
901
#'   , nrounds = 5L
902
903
#'   , valids = valids
#' )
904
#' preds <- predict(model, test$data)
905
906
#'
#' # pass other prediction parameters
907
#' preds <- predict(
908
909
910
911
912
913
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'    )
#' )
914
#' }
915
#' @importFrom utils modifyList
Guolin Ke's avatar
Guolin Ke committed
916
#' @export
James Lamb's avatar
James Lamb committed
917
predict.lgb.Booster <- function(object,
                                newdata,
                                type = "response",
                                start_iteration = NULL,
                                num_iteration = NULL,
                                header = FALSE,
                                params = list(),
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Nothing passed through '...' is forwarded to the booster. Arguments that
  # were removed from the interface raise an error pointing at their
  # replacement; anything else is reported with a warning and ignored.
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    additional_params_names <- names(additional_params)
    if ("reshape" %in% additional_params_names) {
      stop("'reshape' argument is no longer supported.")
    }

    # removed logical flags and the value of 'type' that replaces each of them
    old_args_for_type <- list(
      "rawscore" = "raw"
      , "predleaf" = "leaf"
      , "predcontrib" = "contrib"
    )
    for (arg in names(old_args_for_type)) {
      if (arg %in% additional_params_names) {
        stop(sprintf("Argument '%s' is no longer supported. Use type='%s' instead."
                     , arg
                     , old_args_for_type[[arg]]))
      }
    }

    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , toString(additional_params_names)
      , ". These are ignored. Use argument 'params' instead."
    ))
  }

  # With objective "none" the 'class' and 'response' types are downgraded
  # to raw scores (with a warning).
  if (!is.null(object$params$objective) && object$params$objective == "none" && type %in% c("class", "response")) {
    warning("Prediction types 'class' and 'response' are not supported for custom objectives.")
    type <- "raw"
  }

  # Translate 'type' into the logical flags expected by the booster's predict
  # method. Types "response" and "class" leave all three flags FALSE.
  rawscore <- FALSE
  predleaf <- FALSE
  predcontrib <- FALSE
  if (type == "raw") {
    rawscore <- TRUE
  } else if (type == "leaf") {
    predleaf <- TRUE
  } else if (type == "contrib") {
    predcontrib <- TRUE
  }

  pred <- object$predict(
    data = newdata
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf =  predleaf
    , predcontrib =  predcontrib
    , header = header
    , params = params
  )

  if (type == "class") {
    # The objective may be NULL (the check above allows for it). isTRUE() keeps
    # a missing objective from turning into an "argument is of length zero"
    # error after the prediction has already been computed; in that case, as
    # for any other non-classification objective, 'pred' is returned unchanged.
    objective <- object$params$objective
    if (isTRUE(objective == "binary")) {
      # threshold the predicted values at 0.5 to obtain 0/1 labels
      pred <- as.integer(pred >= 0.5)
    } else if (isTRUE(objective %in% c("multiclass", "multiclassova"))) {
      # column index of the row-wise maximum, shifted to zero-based labels
      pred <- max.col(pred) - 1L
    }
  }
  return(pred)
}

993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
#' @title Configure Fast Single-Row Predictions
#' @description Pre-configures a LightGBM model object to produce fast single-row predictions
#'              for a given input data type, prediction type, and parameters.
#' @details Calling this function multiple times with different parameters might not override
#'          the previous configuration and might trigger undefined behavior.
#'
#'          Any saved configuration for fast predictions might be lost after making a single-row
#'          prediction of a different type than what was configured (except for types "response" and
#'          "class", which can be switched between each other at any time without losing the configuration).
#'
#'          In some situations, setting a fast prediction configuration for one type of prediction
#'          might cause the prediction function to keep using that configuration for single-row
#'          predictions even if the requested type of prediction is different from what was configured.
#'
#'          Note that this function will not accept argument \code{type="class"} - for such cases, one
#'          can pass \code{type="response"} to this function and then \code{type="class"} to the
#'          \code{predict} function - the fast configuration will not be lost or altered if the switch
#'          is between "response" and "class".
#'
#'          The configuration does not survive de-serializations, so it has to be generated
#'          anew in every R process that is going to use it (e.g. if loading a model object
#'          through \code{readRDS}, whatever configuration was there previously will be lost).
#'
#'          Requesting a different prediction type or passing parameters to \link{predict.lgb.Booster}
#'          will cause it to ignore the fast-predict configuration and take the slow route instead
#'          (but be aware that an existing configuration might not always be overridden by supplying
#'          different parameters or prediction type, so make sure to check that the output is what
#'          was expected when a prediction is to be made on a single row for something different than
#'          what is configured).
#'
#'          Note that, if configuring a non-default prediction type (such as leaf indices),
#'          then that type must also be passed in the call to \link{predict.lgb.Booster} in
#'          order for it to use the configuration. This also applies for \code{start_iteration}
#'          and \code{num_iteration}, but \bold{the \code{params} list must be empty} in the call to \code{predict}.
#'
#'          Predictions about feature contributions do not allow a fast route for CSR inputs,
#'          and as such, this function will produce an error if passing \code{csr=TRUE} and
#'          \code{type = "contrib"} together.
#' @inheritParams lgb_predict_shared_params
#' @param model LightGBM model object (class \code{lgb.Booster}).
#'
#'              \bold{The object will be modified in-place}.
#' @param csr Whether the prediction function is going to be called on sparse CSR inputs.
#'            If \code{FALSE}, will be assumed that predictions are going to be called on single-row
#'            regular R matrices.
#' @return The same \code{model} that was passed as input, invisibly, with the desired
#'         configuration stored inside it and available to be used in future calls to
#'         \link{predict.lgb.Booster}.
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(mtcars)
#' X <- as.matrix(mtcars[, -1L])
#' y <- mtcars[, 1L]
#' dtrain <- lgb.Dataset(X, label = y, params = list(max_bin = 5L))
#' params <- list(min_data_in_leaf = 2L)
#' model <- lgb.train(
#'   params = params
#'  , data = dtrain
#'  , obj = "regression"
#'  , nrounds = 5L
#'  , verbose = -1L
#' )
#' lgb.configure_fast_predict(model)
#'
#' x_single <- X[11L, , drop = FALSE]
#' predict(model, x_single)
#'
#' # Will not use it if the prediction to be made
#' # is different from what was configured
#' predict(model, x_single, type = "leaf")
#' }
#' @export
lgb.configure_fast_predict <- function(model,
                                       csr = FALSE,
                                       start_iteration = NULL,
                                       num_iteration = NULL,
                                       type = "response",
                                       params = list()) {
  if (!lgb.is.Booster(x = model)) {
    stop("lgb.configure_fast_predict: model should be an ", sQuote("lgb.Booster"))
  }
  # "class" cannot be configured directly; callers configure "response" instead
  if (type == "class") {
    stop("type='class' is not supported for 'lgb.configure_fast_predict'. Use 'response' instead.")
  }

  # Each flag is TRUE only for its own prediction type; for "response" all
  # three stay FALSE.
  wants_raw <- type == "raw"
  wants_leaf <- type == "leaf"
  wants_contrib <- type == "contrib"

  # Feature contributions cannot be combined with CSR input
  if (csr && wants_contrib) {
    stop("'lgb.configure_fast_predict' does not support feature contributions for CSR data.")
  }

  # Store the configuration inside the booster object itself
  model$configure_fast_predict(
    csr = csr
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = wants_raw
    , predleaf = wants_leaf
    , predcontrib = wants_contrib
    , params = params
  )
  return(invisible(model))
}

1105
1106
1107
1108
1109
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
1110
#' @return The same input \code{x}, returned as invisible.
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  handle <- x$.__enclos_env__$private$handle

  # Without a usable handle nothing can be queried from the model, so only
  # the generic header and a notice are shown.
  if (lgb.is.null.handle(handle)) {
    cat("LightGBM Model\n")
    cat("(Booster handle is invalid)\n")
    return(invisible(x))
  }

  # Header line with the number of trees
  ntrees <- x$current_iter()
  if (ntrees == 1L) {
    cat("LightGBM Model (1 tree)\n")
  } else {
    cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
  }

  # Objective line. A print method should not fail because the objective was
  # never recorded in 'params', so a missing objective is shown as "unknown"
  # instead of erroring on the comparison below. Objective "none" is
  # displayed as "custom".
  obj <- x$params$objective
  if (is.null(obj)) {
    obj <- "unknown"
  } else if (obj == "none") {
    obj <- "custom"
  }

  num_class <- x$.__enclos_env__$private$num_class
  if (num_class == 1L) {
    cat(sprintf("Objective: %s\n", obj))
  } else {
    cat(sprintf("Objective: %s (%d classes)\n"
        , obj
        , num_class))
  }

  # Number of features, queried from the booster handle
  ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
  cat(sprintf("Fitted to dataset with %d columns\n", ncols))
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
1159
#' @return The same input \code{object}, returned as invisible.
1160
1161
1162
1163
1164
#' @export
summary.lgb.Booster <- function(object, ...) {
  # The summary is the printed representation; whatever print() returns is
  # handed back invisibly. Arguments in '...' are not forwarded.
  printed <- print(object)
  return(invisible(printed))
}

1165
1166
#' @name lgb.load
#' @title Load LightGBM model
1167
1168
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
Guolin Ke's avatar
Guolin Ke committed
1169
#' @param filename path of model file
1170
#' @param model_str a str containing the model (as a \code{character} or \code{raw} vector)
1171
#'
1172
#' @return lgb.Booster
1173
#'
Guolin Ke's avatar
Guolin Ke committed
1174
#' @examples
1175
#' \donttest{
1176
1177
1178
1179
1180
1181
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
1182
1183
1184
1185
1186
1187
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
1188
#' valids <- list(test = dtest)
1189
1190
1191
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
1192
#'   , nrounds = 5L
1193
#'   , valids = valids
1194
#'   , early_stopping_rounds = 3L
1195
#' )
1196
1197
1198
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
1199
1200
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
1201
#' }
Guolin Ke's avatar
Guolin Ke committed
1202
#' @export
1203
lgb.load <- function(filename = NULL, model_str = NULL) {

  # A file path takes precedence when both inputs are supplied
  if (!is.null(filename)) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # A zero-length or multi-element vector would otherwise reach the
    # file.exists() condition below and fail with an unrelated
    # condition-length error.
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single string")
    }
    filename <- path.expand(filename)
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  # Otherwise fall back to an in-memory model (character or raw vector)
  if (!is.null(model_str)) {
    if (!is.raw(model_str) && !is.character(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

1229
1230
1231
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
Guolin Ke's avatar
Guolin Ke committed
1232
1233
1234
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
1235
#'
1236
#' @return lgb.Booster
1237
#'
Guolin Ke's avatar
Guolin Ke committed
1238
#' @examples
1239
#' \donttest{
1240
1241
1242
1243
1244
1245
1246
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
1247
1248
1249
1250
1251
1252
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
1253
#' valids <- list(test = dtest)
1254
1255
1256
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
1257
#'   , nrounds = 10L
1258
#'   , valids = valids
1259
#'   , early_stopping_rounds = 5L
1260
#' )
1261
#' lgb.save(model, tempfile(fileext = ".txt"))
1262
#' }
Guolin Ke's avatar
Guolin Ke committed
1263
#' @export
1264
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # Exactly one path is accepted
  if (!is.character(filename) || length(filename) != 1L) {
    stop("lgb.save: filename should be a string")
  }

  # Expand a leading "~" before handing the path to the booster; the value
  # returned by the booster's save method is passed back invisibly.
  saved <- booster$save_model(
    filename = path.expand(filename)
    , num_iteration = num_iteration
  )
  return(invisible(saved))
}

1285
1286
1287
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
Guolin Ke's avatar
Guolin Ke committed
1288
1289
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump, NULL or <= 0 means use best iteration
1290
#'
Guolin Ke's avatar
Guolin Ke committed
1291
#' @return json format of model
1292
#'
Guolin Ke's avatar
Guolin Ke committed
1293
#' @examples
1294
#' \donttest{
1295
1296
1297
1298
1299
1300
1301
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
1302
1303
1304
1305
1306
1307
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
1308
#' valids <- list(test = dtest)
1309
1310
1311
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
1312
#'   , nrounds = 10L
1313
#'   , valids = valids
1314
#'   , early_stopping_rounds = 5L
1315
#' )
1316
#' json_model <- lgb.dump(model)
1317
#' }
Guolin Ke's avatar
Guolin Ke committed
1318
#' @export
1319
lgb.dump <- function(booster, num_iteration = NULL) {

  # Delegate to the booster's own dump method when given a valid booster
  if (lgb.is.Booster(x = booster)) {
    return(booster$dump_model(num_iteration = num_iteration))
  }

  stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
}

1330
1331
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
1332
1333
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
Guolin Ke's avatar
Guolin Ke committed
1334
#' @param booster Object of class \code{lgb.Booster}
1335
1336
1337
1338
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
Guolin Ke's avatar
Guolin Ke committed
1339
#' @param is_err TRUE will return evaluation error instead
1340
#'
1341
#' @return numeric vector of evaluation result
1342
#'
1343
#' @examples
1344
#' \donttest{
1345
#' # train a regression model
1346
1347
1348
1349
1350
1351
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
1352
1353
1354
1355
1356
1357
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
1358
#' valids <- list(test = dtest)
1359
1360
1361
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
1362
#'   , nrounds = 5L
1363
1364
#'   , valids = valids
#' )
1365
1366
1367
1368
1369
1370
1371
1372
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
1373
#' lgb.get.eval.result(model, "test", "l2")
1374
#' }
Guolin Ke's avatar
Guolin Ke committed
1375
#' @export
1376
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , toString(data_names)
      , "]"
    ))
  }

  # Check that the requested metric was recorded for this dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , toString(eval_names)
      , "]"
    ))
  }

  # Look up either the metric values or, when is_err is TRUE, the entry
  # stored under the error key.
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Shift the requested iteration numbers by (start_iter - 1) to obtain
  # positions within 'result' (presumably the first recorded entry
  # corresponds to iteration 'start_iter' - confirm against the recording
  # callback).
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}