#' @importFrom R6 R6Class
#' @importFrom utils modifyList
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Used as the default `num_iteration` by save_model(), save_model_to_string(),
    # dump_model() and predict() when the caller does not supply one
    best_iter = -1L,
    best_score = NA_real_,
    # Parameters the booster was created with (or last reset to via reset_parameter())
    params = list(),
    record_evals = list(),

    # Finalize will free up the native booster handle
    finalize = function() {
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL
      return(invisible(NULL))
    },

    # Create a booster from exactly one source, checked in this order:
    #   train_set - an lgb.Dataset to train on; its parameters are merged over `params`
    #   modelfile - path to a saved model file
    #   model_str - a saved model as a character or raw vector
    # When loading from a model file or model string, `params` is replaced by the
    # parameters stored in the loaded model.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }
        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        # Parameters stored on the Dataset take precedence over `params`
        params <- utils::modifyList(params, train_set$get_params())
        params_str <- lgb.params2str(params = params)
        # Store booster handle
        handle <- .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
        )

        # Create private booster information
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        if (!is.null(private$init_predictor)) {
          # Merge the Dataset's initial predictor booster into the new booster
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
          )
        }

        # One "already predicted at this iteration" flag per dataset; the
        # training set is dataset 1
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        # Do we have a model file as character?
        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        modelfile <- path.expand(modelfile)

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )
        params <- private$get_loaded_param(handle)

      } else if (!is.null(model_str)) {

        # Do we have a model_str as character/raw?
        if (!is.raw(model_str) && !is.character(model_str)) {
          stop("lgb.Booster: Can only use a character/raw vector as model_str")
        }

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )
        # Keep params consistent with the model-file branch
        params <- private$get_loaded_param(handle)

      } else {

        # Booster non existent
        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      private$num_class <- 1L
      # The return value is not used: the native call is expected to fill
      # private$num_class in place
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
      )

      self$params <- params

      return(invisible(NULL))

    },

    # Set training data name (used as the data_name of training-set evaluations)
    set_train_data_name = function(name) {

      private$name_train_set <- name
      return(invisible(self))

    },

    # Add validation data. `data` must be an lgb.Dataset built with the same
    # initial predictor as the training data; `name` must be a character string.
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Add validation data to booster
      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Merge `params` over the current parameters and push them to the native booster
    reset_parameter = function(params) {

      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }

      params_str <- lgb.params2str(params = params)

      self$restore_handle()

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))

    },

    # Perform one boosting iteration.
    #   train_set - optional lgb.Dataset to (re)attach before boosting; it must
    #               share the booster's initial predictor
    #   fobj      - optional custom objective: function(preds, dtrain) returning a
    #               list with `grad` and `hess`, each as long as `preds`
    update = function(train_set = NULL, fobj = NULL) {

      # Re-attach the stored training Dataset if it changed since it was attached.
      # Skip this when no training set is attached (e.g. booster loaded from a model).
      if (is.null(train_set) && !is.null(private$train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # The predictor is a private Dataset field (as in initialize() and
        # add_valid()); a public `train_set$predictor` lookup is always NULL
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # A custom objective was used earlier, so the built-in one is disabled
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost iteration from known objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        # Disable the built-in objective (once) so only the custom gradients are used
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }
        # Perform objective calculation on the current training-set predictions
        preds <- private$inner_predict(1L)
        gpair <- fobj(preds, private$train_set)

        # Check for gradient and hessian; `$` would error on an atomic result
        if (is.atomic(gpair) || is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should return a list with attributes (hess, grad)")
        }

        # Check grad and hess have the right shape
        n_grad <- length(gpair$grad)
        n_hess <- length(gpair$hess)
        n_preds <- length(preds)
        if (n_grad != n_preds) {
          stop(sprintf("Expected custom objective function to return grad with length %d, got %d.", n_preds, n_grad))
        }
        if (n_hess != n_preds) {
          stop(sprintf("Expected custom objective function to return hess with length %d, got %d.", n_preds, n_hess))
        }

        # Return custom boosting gradient/hessian
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , n_preds
        )

      }

      private$invalidate_predict_cache()

      return(invisible(self))

    },

    # Return one iteration behind
    rollback_one_iter = function() {

      self$restore_handle()

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      private$invalidate_predict_cache()

      return(invisible(self))

    },

    # Get current iteration (filled in place by the native call)
    current_iter = function() {

      self$restore_handle()

      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Get upper bound (filled in place by the native call)
    upper_bound = function() {

      self$restore_handle()

      upper_bound <- 0.0
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)

    },

    # Get lower bound (filled in place by the native call)
    lower_bound = function() {

      self$restore_handle()

      lower_bound <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)

    },

    # Evaluate `data` on the configured metrics (plus `feval`, if given).
    # If `data` is neither the training set nor an existing validation set, it is
    # added as a new validation set called `name`.
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Find the dataset index: 1 = training set, i + 1 = validation set i
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1L
            break
          }
        }
      }

      # Unknown data: register it as a validation set
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      # Evaluate data
      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },

    # Evaluation training data
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluation validation data; returns one flat list covering all validation sets
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },

    # Save model to `filename` (defaults num_iteration to best_iter)
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      filename <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },

    # Serialize the model; returns a character string, or raw bytes if as_char = FALSE
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      if (as_char) {
        model_str <- rawToChar(model_str)
      }

      return(model_str)

    },

    # Dump model in memory
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)

    },

    # Predict on new data
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       params = list()) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # possibly override keyword arguments with parameters
      #
      # NOTE: this length() check minimizes the latency introduced by these checks,
      #       for the common case where params is empty
      #
      # NOTE: doing this here instead of in Predictor$predict() to keep
      #       Predictor$predict() as fast as possible
      if (length(params) > 0L) {
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_raw_score"
          , params = params
          , alternative_kwarg_value = rawscore
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_leaf_index"
          , params = params
          , alternative_kwarg_value = predleaf
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_contrib"
          , params = params
          , alternative_kwarg_value = predcontrib
        )
        rawscore <- params[["predict_raw_score"]]
        predleaf <- params[["predict_leaf_index"]]
        predcontrib <- params[["predict_contrib"]]
      }

      # Predict on new data
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
        , fast_predict_config = private$fast_predict_config
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
        )
      )

    },

    # Transform into predictor. Like predict(), restore the native handle first so
    # this also works on a deserialized booster.
    to_predictor = function() {
      self$restore_handle()
      return(Predictor$new(modelfile = private$handle))
    },

    # Pre-build a native fast single-row prediction handle (dense matrix or CSR)
    # and store it, together with its settings, in private$fast_predict_config
    configure_fast_predict = function(csr = FALSE,
                                      start_iteration = NULL,
                                      num_iteration = NULL,
                                      rawscore = FALSE,
                                      predleaf = FALSE,
                                      predcontrib = FALSE,
                                      params = list()) {

      self$restore_handle()
      ncols <- .Call(LGBM_BoosterGetNumFeature_R, private$handle)

      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      if (!csr) {
        fun <- LGBM_BoosterPredictForMatSingleRowFastInit_R
      } else {
        fun <- LGBM_BoosterPredictForCSRSingleRowFastInit_R
      }

      fast_handle <- .Call(
        fun
        , private$handle
        , ncols
        , rawscore
        , predleaf
        , predcontrib
        , start_iteration
        , num_iteration
        , lgb.params2str(params = params)
      )

      private$fast_predict_config <- list(
        handle = fast_handle
        , csr = as.logical(csr)
        , ncols = ncols
        , start_iteration = start_iteration
        , num_iteration = num_iteration
        , rawscore = as.logical(rawscore)
        , predleaf = as.logical(predleaf)
        , predcontrib = as.logical(predcontrib)
        , params = params
      )

      return(invisible(NULL))
    },

    # Used for serialization
    raw = NULL,

    # Store serialized raw bytes in model object
    save_raw = function() {
      if (is.null(self$raw)) {
        self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      }
      return(invisible(NULL))

    },

    drop_raw = function() {
      self$raw <- NULL
      return(invisible(NULL))
    },

    check_null_handle = function() {
      return(lgb.is.null.handle(private$handle))
    },

    # Re-create the native handle from self$raw if it is NULL (e.g. after the
    # object was deserialized). The handle gets the same class initialize() sets.
    restore_handle = function() {
      if (self$check_null_handle()) {
        if (is.null(self$raw)) {
          .Call(LGBM_NullBoosterHandleError_R)
        }
        handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
        class(handle) <- "lgb.Booster.handle"
        private$handle <- handle
      }
      return(invisible(NULL))
    },

    get_handle = function() {
      return(private$handle)
    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,
    train_set_version = 0L,
    fast_predict_config = list(),

    # Mark every dataset's cached predictions as stale (after the model changes)
    invalidate_predict_cache = function() {
      private$is_predicted_cur_iter <- lapply(private$is_predicted_cur_iter, function(flag) FALSE)
      return(invisible(NULL))
    },

    # Return the current predictions for dataset `idx` (1 = training set,
    # k > 1 = validation set k - 1), cached per dataset name until the model changes
    inner_predict = function(idx) {

      # Check for unknown dataset (over the maximum provided range) before idx is
      # used to look anything up
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Store data name
      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the prediction buffer on first use (npred is filled in place)
      if (is.null(private$predict_buffer[[data_name]])) {
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)
      }

      # Refresh the buffer only if it is stale for the current iteration
      if (!private$is_predicted_cur_iter[[idx]]) {
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Get evaluation information (metric names are fetched once and cached)
    get_eval_info = function() {

      if (is.null(private$eval_names)) {
        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {

          # Parse and store privately names
          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },

    # Read the parameters stored in a loaded model and return them as an R list
    get_loaded_param = function(handle) {
      params_str <- .Call(
        LGBM_BoosterGetLoadedParam_R
        , handle
      )
      params <- jsonlite::fromJSON(params_str)
      if ("interaction_constraints" %in% names(params)) {
        constraints <- params[["interaction_constraints"]]
        # fromJSON() simplifies equal-length inner arrays into a matrix (one row per
        # constraint); lapply() over a matrix would visit single cells, so split rows
        if (is.matrix(constraints)) {
          constraints <- lapply(seq_len(nrow(constraints)), function(i) constraints[i, ])
        }
        # NOTE(review): presumably the model stores 0-based feature indices and R uses 1-based
        params[["interaction_constraints"]] <- lapply(constraints, function(x) x + 1L)
      }

      return(params)

    },

    # Evaluate dataset `data_idx` on the built-in metrics and, optionally, on
    # `feval`; each result is a list(data_name, name, value, higher_better)
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Create evaluation values (filled in place by the native call)
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {

          # Store evaluation and append to return
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Check if there are evaluation metrics
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set

        # Check if data to assess is existing differently
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        # Perform function evaluation
        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)"
          )
        }

        # Append names and evaluation
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }

  )
)

#' @name lgb_predict_shared_params
#' @param type Type of prediction to output. Allowed types are:\itemize{
#'             \item \code{"response"}: will output the predicted score according to the objective function being
#'                   optimized (depending on the link function that the objective uses), after applying any necessary
#'                   transformations - for example, for \code{objective="binary"}, it will output class probabilities.
#'             \item \code{"class"}: for classification objectives, will output the class with the highest predicted
#'                   probability. For other objectives, will output the same as "response". Note that \code{"class"} is
#'                   not a supported type for \link{lgb.configure_fast_predict} (see the documentation of that function
#'                   for more details).
#'             \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations'
#'                   results) from which the "response" number is produced for a given objective function - for example,
#'                   for \code{objective="binary"}, this corresponds to log-odds. For many objectives such as
#'                   "regression", since no transformation is applied, the output will be the same as for "response".
#'             \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observation falls
#'                   in each tree in the model, returned as integers, with one column per tree.
#'             \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an
#'                   intercept (each feature will produce one column).
#'             }
#'
#'             Note that, if using custom objectives, types "class" and "response" will not be available and will
#'             fall back to "raw" instead.
#' @param start_iteration int or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration int or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values. Where these conflict with the values of keyword arguments to this function,
#'               the values in \code{params} take precedence.
NULL

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @details If the model object has been configured for fast single-row predictions through
#'          \link{lgb.configure_fast_predict}, this function will use the prediction parameters
#'          that were configured for it - as such, extra prediction parameters should not be passed
#'          here, otherwise the configuration will be ignored and the slow route will be taken.
#'
#'          When the model was trained with a custom objective (stored objective \code{"none"}),
#'          \code{type="class"} and \code{type="response"} are not supported: a warning is emitted
#'          and raw scores are returned instead.
#'
#'          For \code{type="class"}: with objective \code{"binary"}, predictions are converted to
#'          \code{1L} when the predicted response is \code{>= 0.5} and \code{0L} otherwise; with
#'          objective \code{"multiclass"} or \code{"multiclassova"}, the 0-based index of the column
#'          with the highest predicted value is returned (ties resolve to the lowest index). For any
#'          other objective, the response predictions are returned unchanged.
#' @inheritParams lgb_predict_shared_params
#' @param object Object of class \code{lgb.Booster}
#' @param newdata a \code{matrix} object, a \code{dgCMatrix}, a \code{dgRMatrix} object, a \code{dsparseVector} object,
#'                or a character representing a path to a text file (CSV, TSV, or LibSVM).
#'
#'                For sparse inputs, if predictions are only going to be made for a single row, it will be faster to
#'                use CSR format, in which case the data may be passed as either a single-row CSR matrix (class
#'                \code{dgRMatrix} from package \code{Matrix}) or as a sparse numeric vector (class
#'                \code{dsparseVector} from package \code{Matrix}).
#'
#'                If single-row predictions are going to be performed frequently, it is recommended to
#'                pre-configure the model object for fast single-row sparse predictions through function
#'                \link{lgb.configure_fast_predict}.
#' @param header only used for prediction for text file. True if text file has header
#' @param ... ignored
#' @return For prediction types that are meant to always return one output per observation (e.g. when predicting
#'         \code{type="response"} or \code{type="raw"} on a binary classification or regression objective), will
#'         return a vector with one element per row in \code{newdata}.
#'
#'         For prediction types that are meant to return more than one output per observation (e.g. when predicting
#'         \code{type="response"} or \code{type="raw"} on a multi-class objective, or when predicting
#'         \code{type="leaf"}, regardless of objective), will return a matrix with one row per observation in
#'         \code{newdata} and one column per output.
#'
#'         For \code{type="leaf"} predictions, will return a matrix with one row per observation in \code{newdata}
#'         and one column per tree. Note that for multiclass objectives, LightGBM trains one tree per class at each
#'         boosting iteration. That means that, for example, for a multiclass model with 3 classes, the leaf
#'         predictions for the first class can be found in columns 1, 4, 7, 10, etc.
#'
#'         For \code{type="contrib"}, will return a matrix of SHAP values with one row per observation in
#'         \code{newdata} and columns corresponding to features. For regression, ranking, cross-entropy, and binary
#'         classification objectives, this matrix contains one column per feature plus a final column containing the
#'         Shapley base value. For multiclass objectives, this matrix will represent \code{num_classes} such matrices,
#'         in the order "feature contributions for first class, feature contributions for second class, feature
#'         contributions for third class, etc.".
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'     )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                newdata,
                                type = "response",
                                start_iteration = NULL,
                                num_iteration = NULL,
                                header = FALSE,
                                params = list(),
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Anything passed through '...' is ignored; reject arguments from the old interface
  # with a pointer to their replacement, and warn about everything else.
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    additional_params_names <- names(additional_params)
    if ("reshape" %in% additional_params_names) {
      stop("'reshape' argument is no longer supported.")
    }

    old_args_for_type <- list(
      "rawscore" = "raw"
      , "predleaf" = "leaf"
      , "predcontrib" = "contrib"
    )
    for (arg in names(old_args_for_type)) {
      if (arg %in% additional_params_names) {
        stop(sprintf("Argument '%s' is no longer supported. Use type='%s' instead."
                     , arg
                     , old_args_for_type[[arg]]))
      }
    }

    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , toString(names(additional_params))
      , ". These are ignored. Use argument 'params' instead."
    ))
  }

  # May be NULL if no objective is stored in params; identical()/isTRUE() below keep
  # that case from erroring in an 'if' condition.
  objective <- object$params$objective

  # Custom objectives ("none") can only produce raw scores.
  if (identical(objective, "none") && type %in% c("class", "response")) {
    warning("Prediction types 'class' and 'response' are not supported for custom objectives.")
    type <- "raw"
  }

  rawscore <- FALSE
  predleaf <- FALSE
  predcontrib <- FALSE
  if (type == "raw") {
    rawscore <- TRUE
  } else if (type == "leaf") {
    predleaf <- TRUE
  } else if (type == "contrib") {
    predcontrib <- TRUE
  }

  pred <- object$predict(
    data = newdata
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , params = params
  )

  if (type == "class") {
    if (identical(objective, "binary")) {
      pred <- as.integer(pred >= 0.5)
    } else if (isTRUE(objective %in% c("multiclass", "multiclassova"))) {
      # ties.method = "first" keeps labels deterministic: the default "random" treats values
      # within a 1e-5 relative tolerance as ties and breaks them with the RNG.
      pred <- max.col(pred, ties.method = "first") - 1L
    }
  }
  return(pred)
}

1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
#' @title Configure Fast Single-Row Predictions
#' @description Pre-configures a LightGBM model object to produce fast single-row predictions
#'              for a given input data type, prediction type, and parameters.
#' @details Calling this function multiple times with different parameters might not override
#'          the previous configuration and might trigger undefined behavior.
#'
#'          Any saved configuration for fast predictions might be lost after making a single-row
#'          prediction of a different type than what was configured (except for types "response" and
#'          "class", which can be switched between each other at any time without losing the configuration).
#'
#'          In some situations, setting a fast prediction configuration for one type of prediction
#'          might cause the prediction function to keep using that configuration for single-row
#'          predictions even if the requested type of prediction is different from what was configured.
#'
#'          Note that this function will not accept argument \code{type="class"} - for such cases, one
#'          can pass \code{type="response"} to this function and then \code{type="class"} to the
#'          \code{predict} function - the fast configuration will not be lost or altered if the switch
#'          is between "response" and "class".
#'
#'          The configuration does not survive de-serializations, so it has to be generated
#'          anew in every R process that is going to use it (e.g. if loading a model object
#'          through \code{readRDS}, whatever configuration was there previously will be lost).
#'
#'          Requesting a different prediction type or passing parameters to \link{predict.lgb.Booster}
#'          will cause it to ignore the fast-predict configuration and take the slow route instead
#'          (but be aware that an existing configuration might not always be overridden by supplying
#'          different parameters or prediction type, so make sure to check that the output is what
#'          was expected when a prediction is to be made on a single row for something different than
#'          what is configured).
#'
#'          Note that, if configuring a non-default prediction type (such as leaf indices),
#'          then that type must also be passed in the call to \link{predict.lgb.Booster} in
#'          order for it to use the configuration. This also applies for \code{start_iteration}
#'          and \code{num_iteration}, but \bold{the \code{params} list must be empty} in the call to \code{predict}.
#'
#'          Predictions about feature contributions do not allow a fast route for CSR inputs,
#'          and as such, this function will produce an error if passing \code{csr=TRUE} and
#'          \code{type = "contrib"} together.
#' @inheritParams lgb_predict_shared_params
#' @param model LightGBM model object (class \code{lgb.Booster}).
#'
#'              \bold{The object will be modified in-place}.
#' @param csr Whether the prediction function is going to be called on sparse CSR inputs.
#'            If \code{FALSE}, it will be assumed that predictions are going to be called on single-row
#'            regular R matrices.
#' @return The same \code{model} that was passed as input, invisibly, with the desired
#'         configuration stored inside it and available to be used in future calls to
#'         \link{predict.lgb.Booster}.
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(mtcars)
#' X <- as.matrix(mtcars[, -1L])
#' y <- mtcars[, 1L]
#' dtrain <- lgb.Dataset(X, label = y, params = list(max_bin = 5L))
#' params <- list(min_data_in_leaf = 2L)
#' model <- lgb.train(
#'   params = params
#'  , data = dtrain
#'  , obj = "regression"
#'  , nrounds = 5L
#'  , verbose = -1L
#' )
#' lgb.configure_fast_predict(model)
#'
#' x_single <- X[11L, , drop = FALSE]
#' predict(model, x_single)
#'
#' # Will not use it if the prediction to be made
#' # is different from what was configured
#' predict(model, x_single, type = "leaf")
#' }
#' @export
lgb.configure_fast_predict <- function(model,
                                       csr = FALSE,
                                       start_iteration = NULL,
                                       num_iteration = NULL,
                                       type = "response",
                                       params = list()) {
  if (!lgb.is.Booster(x = model)) {
    stop("lgb.configure_fast_predict: model should be an ", sQuote("lgb.Booster"))
  }
  if (type == "class") {
    stop("type='class' is not supported for 'lgb.configure_fast_predict'. Use 'response' instead.")
  }

  # Translate 'type' into the boolean flags the Booster method expects. At most one of
  # them is TRUE; "response" (the default) leaves all three FALSE.
  want_raw <- isTRUE(type == "raw")
  want_leaf <- isTRUE(type == "leaf")
  want_contrib <- isTRUE(type == "contrib")

  # 'csr' is evaluated first so a non-logical 'csr' fails the same way regardless of 'type'
  if (csr && want_contrib) {
    stop("'lgb.configure_fast_predict' does not support feature contributions for CSR data.")
  }

  model$configure_fast_predict(
    csr = csr
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = want_raw
    , predleaf = want_leaf
    , predcontrib = want_contrib
    , params = params
  )
  return(invisible(model))
}

1120
1121
1122
1123
1124
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#'              Prints the number of trees, the objective (with the number of classes when there
#'              is more than one), and the number of columns of the training data. If the model
#'              has no valid booster handle, only reports that the handle is invalid.
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input \code{x}, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  handle <- x$.__enclos_env__$private$handle

  # Without a valid handle nothing else can be queried from the booster.
  if (lgb.is.null.handle(handle)) {
    cat("LightGBM Model\n")
    cat("(Booster handle is invalid)\n")
    return(invisible(x))
  }

  ntrees <- x$current_iter()
  if (ntrees == 1L) {
    cat("LightGBM Model (1 tree)\n")
  } else {
    cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
  }

  # params may not carry an objective; skip that line rather than erroring on a
  # zero-length comparison inside 'if'.
  obj <- x$params$objective
  if (!is.null(obj)) {
    # "none" is how custom objective functions are stored
    if (identical(obj, "none")) {
      obj <- "custom"
    }
    num_class <- x$.__enclos_env__$private$num_class
    if (num_class == 1L) {
      cat(sprintf("Objective: %s\n", obj))
    } else {
      cat(sprintf("Objective: %s (%d classes)\n"
          , obj
          , num_class))
    }
  }

  ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
  cat(sprintf("Fitted to dataset with %d columns\n", ncols))
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input \code{object}, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # Delegates to the print method; its (invisible) return value is passed through.
  print(object)
}

1180
1181
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
#' @param filename path of model file (a single string)
#' @param model_str a str containing the model (as a \code{character} or \code{raw} vector)
#'
#' @return lgb.Booster, returned invisibly
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # A file path takes precedence over a model string.
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # file.exists() below is used as an 'if' condition, so exactly one path is required
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single path (a character vector of length 1)")
    }
    filename <- path.expand(filename)
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.raw(model_str) && !is.character(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

1244
1245
1246
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model to a file
#' @param booster Object of class \code{lgb.Booster}
#' @param filename path of the file to write (a single, non-missing string)
#' @param num_iteration number of iterations to save. \code{NULL} or a value <= 0 means use the best iteration
#'
#' @return lgb.Booster (the value returned by \code{booster$save_model()}), invisibly
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # NA_character_ passes is.character() but is not a usable path
  if (!(is.character(filename) && length(filename) == 1L && !is.na(filename))) {
    stop("lgb.save: filename should be a string")
  }
  filename <- path.expand(filename)

  # Store booster
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )
}

1300
1301
1302
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dump. \code{NULL} or a value <= 0 means
#'                      use the best iteration
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))
}

1345
1346
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for (a single string).
#' @param eval_name Name of the evaluation metric to return results for (a single string).
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }
  # both names are used in 'if (x %in% ...)' conditions below, which need a single value
  if (length(data_name) != 1L || length(eval_name) != 1L) {
    stop("lgb.get.eval.result: data_name and eval_name should each be a single string")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , toString(data_names)
      , "]"
    ))
  }

  # Check if evaluation result is existing
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , toString(eval_names)
      , "]"
    ))
  }

  # Pick either the metric values or their errors
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Recorded results start at iteration 'start_iter', so shift the requested iteration
  # numbers by (start_iter - 1) to turn them into positions within 'result'.
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}