#' @importFrom R6 R6Class
#' @importFrom utils modifyList
Booster <- R6::R6Class(
  # R6 wrapper around a LightGBM booster handle obtained from the C API.
  # A Booster is created from exactly one of: an lgb.Dataset (for training),
  # a model file path, or a model string / raw vector.
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    best_iter = -1L,
    best_score = NA_real_,
    params = list(),
    record_evals = list(),

    # Finalize will free up the handles
    finalize = function() {
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL
      return(invisible(NULL))
    },

    # Initialize will create a starter booster.
    #   params:    list of LightGBM parameters
    #   train_set: lgb.Dataset to train on (takes priority over the other sources)
    #   modelfile: path to a saved model file
    #   model_str: character or raw vector holding a serialized model
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }
        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        # Parameters stored on the Dataset override those passed in `params`
        params <- utils::modifyList(params, train_set$get_params())
        params_str <- lgb.params2str(params = params)
        # Store booster handle
        handle <- .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
        )

        # Create private booster information
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        if (!is.null(private$init_predictor)) {

          # Merge the initial predictor's booster into the new one
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
          )

        }

        # Slot 1 of is_predicted_cur_iter tracks the training set
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        # Do we have a model file as character?
        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        modelfile <- path.expand(modelfile)

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )

      } else if (!is.null(model_str)) {

        # Do we have a model_str as character/raw?
        if (!is.raw(model_str) && !is.character(model_str)) {
          stop("lgb.Booster: Can only use a character/raw vector as model_str")
        }

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )

      } else {

        # Booster non existent
        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      # num_class is passed to the C routine and its return value is unused,
      # so it is presumably filled in place with the model's class count
      private$num_class <- 1L
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
      )

      self$params <- params

      return(invisible(NULL))

    },

    # Set training data name
    set_train_data_name = function(name) {

      # Set name
      private$name_train_set <- name
      return(invisible(self))

    },

    # Add validation data; `data` must share the training set's predictor
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Add validation data to booster
      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Merge `params` into the current parameters and push them to the booster
    reset_parameter = function(params) {

      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }

      params_str <- lgb.params2str(params = params)

      self$restore_handle()

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))

    },

    # Perform boosting update iteration.
    #   train_set: optional new lgb.Dataset to train on
    #   fobj:      optional custom objective, called as fobj(preds, train_set)
    #              and expected to return list(grad = ..., hess = ...)
    update = function(train_set = NULL, fobj = NULL) {

      if (is.null(train_set)) {
        # Boosters created from a model file / string have no training data;
        # without this guard the version check below fails with a cryptic
        # "argument is of length zero" error
        if (is.null(private$train_set)) {
          stop("lgb.Booster.update: no training data available; pass an lgb.Dataset via 'train_set'")
        }
        # Re-attach the stored Dataset if it changed since it was last attached
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # The predictor is a private field of lgb.Dataset (same lookup as in
        # initialize() and add_valid()); `train_set$predictor` is always NULL
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      # Check if objective is empty
      if (is.null(fobj)) {
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost iteration from known objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        # A custom objective replaces the built-in one, so switch it off once
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }
        # Perform objective calculation
        preds <- private$inner_predict(1L)
        gpair <- fobj(preds, private$train_set)

        # Check for gradient and hessian as list
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should return a list with attributes (hess, grad)")
        }

        # Check grad and hess have the right shape
        n_grad <- length(gpair$grad)
        n_hess <- length(gpair$hess)
        n_preds <- length(preds)
        if (n_grad != n_preds) {
          stop(sprintf("Expected custom objective function to return grad with length %d, got %d.", n_preds, n_grad))
        }
        if (n_hess != n_preds) {
          stop(sprintf("Expected custom objective function to return hess with length %d, got %d.", n_preds, n_hess))
        }

        # Return custom boosting gradient/hessian
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , n_preds
        )

      }

      # Invalidate cached predictions for every attached dataset
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Return one iteration behind
    rollback_one_iter = function() {

      self$restore_handle()

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # Invalidate cached predictions for every attached dataset
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Get current iteration (cur_iter is passed to the C routine, whose
    # return value is unused, and then returned)
    current_iter = function() {

      self$restore_handle()

      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Get upper bound
    upper_bound = function() {

      self$restore_handle()

      upper_bound <- 0.0
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)

    },

    # Get lower bound
    lower_bound = function() {

      self$restore_handle()

      lower_bound <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)

    },

    # Evaluate data on metrics; unknown datasets are attached as validation
    # data under `name` first
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate `data` among the attached datasets:
      # 1 = training set, i + 1 = i-th validation set, 0 = not attached yet
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {

        # Check for validation data
        if (length(private$valid_sets) > 0L) {

          for (i in seq_along(private$valid_sets)) {

            # Check for identical validation data with training data
            if (identical(data, private$valid_sets[[i]])) {

              # Found identical data, skip
              data_idx <- i + 1L
              break

            }

          }

        }

      }

      # Check if evaluation was not done
      if (data_idx == 0L) {

        # Add validation data by name
        self$add_valid(data, name)
        data_idx <- private$num_dataset

      }

      # Evaluate data
      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },

    # Evaluation training data
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluation validation data
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },

    # Save model; num_iteration defaults to best_iter
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      filename <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },

    # Serialize model; returns a character string, or the raw bytes from the
    # C routine when as_char = FALSE
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      if (as_char) {
        model_str <- rawToChar(model_str)
      }

      return(model_str)

    },

    # Dump model in memory
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)

    },

    # Predict on new data
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       params = list()) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # possibly override keyword arguments with parameters
      #
      # NOTE: this length() check minimizes the latency introduced by these checks,
      #       for the common case where params is empty
      #
      # NOTE: doing this here instead of in Predictor$predict() to keep
      #       Predictor$predict() as fast as possible
      if (length(params) > 0L) {
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_raw_score"
          , params = params
          , alternative_kwarg_value = rawscore
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_leaf_index"
          , params = params
          , alternative_kwarg_value = predleaf
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_contrib"
          , params = params
          , alternative_kwarg_value = predcontrib
        )
        rawscore <- params[["predict_raw_score"]]
        predleaf <- params[["predict_leaf_index"]]
        predcontrib <- params[["predict_contrib"]]
      }

      # Predict on new data
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
        )
      )

    },

    # Transform into predictor (restoring the handle first, as predict() does)
    to_predictor = function() {
      self$restore_handle()
      return(Predictor$new(modelfile = private$handle))
    },

    # Used for serialization
    raw = NULL,

    # Store serialized raw bytes in model object
    save_raw = function() {
      if (is.null(self$raw)) {
        self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      }
      return(invisible(NULL))

    },

    drop_raw = function() {
      self$raw <- NULL
      return(invisible(NULL))
    },

    check_null_handle = function() {
      return(lgb.is.null.handle(private$handle))
    },

    # Rebuild the handle from `self$raw` when it is null
    restore_handle = function() {
      if (self$check_null_handle()) {
        if (is.null(self$raw)) {
          .Call(LGBM_NullBoosterHandleError_R)
        }
        private$handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
      }
      return(invisible(NULL))
    },

    get_handle = function() {
      return(private$handle)
    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,
    train_set_version = 0L,
    # Predict data; idx 1 is the training set, idx k > 1 is validation set k - 1.
    # Predictions are cached per dataset name until the next update/rollback.
    inner_predict = function(idx) {

      # Check for unknown dataset (over the maximum provided range) before
      # indexing name_valid_sets, so callers get this message rather than a
      # "subscript out of bounds" error
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Store data name
      data_name <- private$name_train_set

      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Check for prediction buffer
      if (is.null(private$predict_buffer[[data_name]])) {

        # Store predictions
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Check if current iteration was already predicted
      if (!private$is_predicted_cur_iter[[idx]]) {

        # Use buffer
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Get evaluation information (metric names are fetched once and cached)
    get_eval_info = function() {

      if (is.null(private$eval_names)) {
        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {

          # Parse and store privately names
          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },

    # Evaluate built-in metrics (and optional custom `feval`) on dataset data_idx
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Create evaluation values
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {

          # Store evaluation and append to return
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Check if there are evaluation metrics
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set

        # Check if data to assess is existing differently
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        # Perform function evaluation
        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) ||  is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)"
          )
        }

        # Append names and evaluation
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }

  )
)

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param newdata a \code{matrix} object, a \code{dgCMatrix} object or
#'                a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param type Type of prediction to output. Allowed types are:\itemize{
#'             \item \code{"response"}: will output the predicted score according to the objective function being
#'                   optimized (depending on the link function that the objective uses), after applying any necessary
#'                   transformations - for example, for \code{objective="binary"}, it will output class probabilities.
#'             \item \code{"class"}: for classification objectives, will output the class with the highest predicted
#'                   probability. For other objectives (or when the objective is not recorded in the model's
#'                   parameters), will output the same as "response".
#'             \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations'
#'                   results) from which the "response" number is produced for a given objective function - for example,
#'                   for \code{objective="binary"}, this corresponds to log-odds. For many objectives such as
#'                   "regression", since no transformation is applied, the output will be the same as for "response".
#'             \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observation falls
#'                   in each tree in the model, outputted as integers, with one column per tree.
#'             \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an
#'                   intercept (each feature will produce one column). If there are multiple classes, each class will
#'                   have separate feature contributions (thus the number of columns is features+1 multiplied by the
#'                   number of classes).
#'             }
#'
#'             Note that, if using custom objectives, types "class" and "response" will not be available and will
#'             default towards using "raw" instead.
#' @param start_iteration int or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration int or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param header only used for prediction for text file. True if text file has header
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values. Where these conflict with the values of keyword arguments to this function,
#'               the values in \code{params} take precedence.
#' @param ... ignored
#' @return For prediction types that are meant to always return one output per observation (e.g. when predicting
#'         \code{type="response"} on a binary classification or regression objective), will return a vector with one
#'         element per row in \code{newdata}.
#'
#'         For prediction types that are meant to return more than one output per observation (e.g. when predicting
#'         \code{type="response"} on a multi-class objective, or when predicting \code{type="leaf"}, regardless of
#'         objective), will return a matrix with one row per observation in \code{newdata} and one column per output.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'     )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                newdata,
                                type = "response",
                                start_iteration = NULL,
                                num_iteration = NULL,
                                header = FALSE,
                                params = list(),
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Reject removed arguments explicitly; warn about anything else in `...`
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    additional_params_names <- names(additional_params)
    if ("reshape" %in% additional_params_names) {
      stop("'reshape' argument is no longer supported.")
    }

    old_args_for_type <- list(
      "rawscore" = "raw"
      , "predleaf" = "leaf"
      , "predcontrib" = "contrib"
    )
    for (arg in names(old_args_for_type)) {
      if (arg %in% additional_params_names) {
        stop(sprintf("Argument '%s' is no longer supported. Use type='%s' instead."
                     , arg
                     , old_args_for_type[[arg]]))
      }
    }

    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , toString(names(additional_params))
      , ". These are ignored. Use argument 'params' instead."
    ))
  }

  # May be NULL, e.g. for a model loaded from file with empty params
  objective <- object$params$objective

  # Custom objectives set the booster objective to "none", so there is no
  # link function to produce responses or classes from the raw scores
  if (!is.null(objective) && objective == "none" && type %in% c("class", "response")) {
    warning("Prediction types 'class' and 'response' are not supported for custom objectives.")
    type <- "raw"
  }

  rawscore <- FALSE
  predleaf <- FALSE
  predcontrib <- FALSE
  if (type == "raw") {
    rawscore <- TRUE
  } else if (type == "leaf") {
    predleaf <- TRUE
  } else if (type == "contrib") {
    predcontrib <- TRUE
  }

  pred <- object$predict(
    data = newdata
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf =  predleaf
    , predcontrib =  predcontrib
    , header = header
    , params = params
  )

  # Convert probabilities to class labels for known classification objectives.
  # Guarding against a NULL objective avoids an "argument is of length zero"
  # error; in that case the response output is returned unchanged.
  if (type == "class" && !is.null(objective)) {
    if (objective == "binary") {
      pred <- as.integer(pred >= 0.5)
    } else if (objective %in% c("multiclass", "multiclassova")) {
      pred <- max.col(pred) - 1L
    }
  }
  return(pred)
}

#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#'              Prints the number of trees, the objective, the number of classes (when more
#'              than one) and the number of feature columns.
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  handle <- x$.__enclos_env__$private$handle

  # Without a valid native handle nothing else can be queried, so only report
  # that the handle is invalid.
  if (lgb.is.null.handle(handle)) {
    cat("LightGBM Model\n")
    cat("(Booster handle is invalid)\n")
    return(invisible(x))
  }

  ntrees <- x$current_iter()
  if (ntrees == 1L) {
    cat("LightGBM Model (1 tree)\n")
  } else {
    cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
  }

  # `params` may not record an objective (NOTE(review): presumably the case for
  # boosters created without params, e.g. via lgb.load()). Fall back to
  # "unknown" instead of erroring on a zero-length comparison.
  obj <- x$params$objective
  if (is.null(obj)) {
    obj <- "unknown"
  } else if (is.function(obj) || identical(obj, "none")) {
    obj <- "custom"
  }

  num_class <- x$.__enclos_env__$private$num_class
  if (num_class == 1L) {
    cat(sprintf("Objective: %s\n", obj))
  } else {
    cat(sprintf("Objective: %s (%d classes)\n"
        , obj
        , num_class))
  }

  ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
  cat(sprintf("Fitted to dataset with %d columns\n", ncols))
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # summary() and print() deliberately produce the same report. The print
  # method hands `object` back invisibly, so that value is passed through.
  shown <- print(object)
  return(invisible(shown))
}

970
971
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
#' @param filename path of model file (a single string)
#' @param model_str a str containing the model (as a `character` or `raw` vector)
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # A file path takes precedence over a model string when both are given.
  if (filename_provided) {
    # Require exactly one non-missing path. A longer vector would turn the
    # file.exists() check below into a multi-element `if` condition, which
    # errors in recent versions of R.
    if (!(is.character(filename) && length(filename) == 1L && !is.na(filename))) {
      stop("lgb.load: filename should be character with length 1")
    }
    filename <- path.expand(filename)
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.raw(model_str) && !is.character(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

1034
1035
1036
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model to a text file.
#' @param booster Object of class \code{lgb.Booster}
#' @param filename path of the file to write (a single, non-missing string)
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return The value returned by the booster's \code{save_model()} method (the
#'         \code{lgb.Booster}), invisibly.
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # Require exactly one non-missing path; NA would otherwise survive
  # path.expand() and be passed on to save_model().
  if (!(is.character(filename) && length(filename) == 1L && !is.na(filename))) {
    stop("lgb.save: filename should be a string")
  }
  filename <- path.expand(filename)

  # Store booster
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )
}

1090
1091
1092
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dump, NULL or <= 0 means use best iteration
#'
#' @return JSON representation of the model, as returned by the booster's
#'         \code{dump_model()} method
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # The message names this function (it previously said "lgb.save:").
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))
}

1135
1136
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for (a single string).
#' @param eval_name Name of the evaluation metric to return results for (a single string).
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Each name must be one string: a longer vector would make the %in% checks
  # below yield several values, which `if` cannot use.
  is_single_string <- function(x) {
    is.character(x) && length(x) == 1L
  }
  if (!is_single_string(data_name) || !is_single_string(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters of length 1")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , toString(data_names)
      , "]"
    ))
  }

  # Check that the requested metric was recorded for this dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , toString(eval_names)
      , "]"
    ))
  }

  # Pick either the recorded metric values or their recorded errors
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # The first stored value corresponds to iteration `start_iter`, so shift the
  # requested iteration numbers into positions within `result`.
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1L
  iters <- iters - delta

  return(as.numeric(result[iters]))
}