lgb.Booster.R 31.8 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
#' @importFrom R6 R6Class
#' @importFrom utils modifyList
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Default iteration limit used by save/dump/predict when num_iteration is NULL
    best_iter = -1L,
    best_score = NA_real_,
    # Parameters this booster was created / last reset with
    params = list(),
    record_evals = list(),

    # Finalize will free up the native booster handle
    finalize = function() {
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL
      return(invisible(NULL))
    },

    # Create a booster from exactly one of: a training lgb.Dataset, a model
    # file path, or a model string (checked in that order of precedence).
    #
    # params    : list of parameters; when train_set is given, the dataset's
    #             own parameters are merged on top of these
    # train_set : lgb.Dataset used for training
    # modelfile : character path to a saved model file
    # model_str : character or raw vector holding a serialized model
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }
        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        params <- utils::modifyList(params, train_set$get_params())
        params_str <- lgb.params2str(params = params)
        # Store booster handle
        handle <- .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
        )

        # Create private booster information
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        if (!is.null(private$init_predictor)) {

          # Merge the initial model (continued training) into this booster
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
          )

        }

        # One "already predicted this iteration" flag per dataset; slot 1 is training
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        # Do we have a model file as character?
        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        modelfile <- path.expand(modelfile)

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )

      } else if (!is.null(model_str)) {

        # Do we have a model_str as character/raw?
        if (!is.raw(model_str) && !is.character(model_str)) {
          stop("lgb.Booster: Can only use a character/raw vector as model_str")
        }

        # Create booster from model
        handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )

      } else {

        # Booster non existent
        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      # The C routine writes the class count into its second argument in
      # place, so hand it a freshly allocated vector rather than a literal
      # constant that may be shared with other code.
      num_class <- integer(1L)
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , num_class
      )
      private$num_class <- num_class

      self$params <- params

      return(invisible(NULL))

    },

    # Set the name under which training-set evaluations are reported
    set_train_data_name = function(name) {

      private$name_train_set <- name
      return(invisible(self))

    },

    # Register an lgb.Dataset as validation data under the given name
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Add validation data to booster
      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Merge `params` into the current parameters and push them to the native booster
    reset_parameter = function(params) {

      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }

      params_str <- lgb.params2str(params = params)

      self$restore_handle()

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))

    },

    # Perform one boosting iteration.
    #
    # train_set : optional replacement training lgb.Dataset; when NULL, the
    #             stored training set is re-sent only if its version changed
    # fobj      : optional custom objective, called as fobj(preds, train_set)
    #             and expected to return list(grad = ..., hess = ...)
    update = function(train_set = NULL, fobj = NULL) {

      if (is.null(train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # `predictor` is a private field of lgb.Dataset (see initialize() and
        # add_valid()); the public accessor train_set$predictor is always NULL.
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # Once a custom objective has switched the booster to objective "none",
        # a built-in update is no longer possible.
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost iteration from known objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }
        # Perform objective calculation on the current training predictions
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        # Check for gradient and hessian as list
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should
            return a list with attributes (hess, grad)")
        }

        # Return custom boosting gradient/hessian
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # The model changed, so every cached prediction is stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Undo the most recent boosting iteration
    rollback_one_iter = function() {

      self$restore_handle()

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # The model changed, so every cached prediction is stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Get current iteration
    current_iter = function() {

      self$restore_handle()

      # Freshly allocated: the C routine overwrites this vector in place
      cur_iter <- integer(1L)
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Get upper bound of the model's output, as reported by the native booster
    upper_bound = function() {

      self$restore_handle()

      # Freshly allocated: the C routine overwrites this vector in place
      upper_bound <- numeric(1L)
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)

    },

    # Get lower bound of the model's output, as reported by the native booster
    lower_bound = function() {

      self$restore_handle()

      # Freshly allocated: the C routine overwrites this vector in place
      lower_bound <- numeric(1L)
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)

    },

    # Evaluate `data` on the configured metrics (and optional feval).
    # Data not yet known to the booster is registered as validation data.
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Look up the dataset index: 1 = training set, i + 1 = i-th validation set
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {

        if (length(private$valid_sets) > 0L) {

          for (i in seq_along(private$valid_sets)) {

            if (identical(data, private$valid_sets[[i]])) {

              data_idx <- i + 1L
              break

            }

          }

        }

      }

      # Unknown dataset: register it as validation data first
      if (data_idx == 0L) {

        self$add_valid(data, name)
        data_idx <- private$num_dataset

      }

      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },

    # Evaluate the training data
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluate every registered validation set, concatenating the results
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },

    # Save model to a file; num_iteration defaults to self$best_iter
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      filename <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },

    # Serialize model; returns a character string, or the raw bytes when as_char = FALSE
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
          LGBM_BoosterSaveModelToString_R
          , private$handle
          , as.integer(num_iteration)
          , as.integer(feature_importance_type)
      )

      if (as_char) {
        model_str <- rawToChar(model_str)
      }

      return(model_str)

    },

    # Dump model in memory (whatever LGBM_BoosterDumpModel_R returns)
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)

    },

    # Predict on new data
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       params = list()) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # possibly override keyword arguments with parameters
      #
      # NOTE: this length() check minimizes the latency introduced by these checks,
      #       for the common case where params is empty
      #
      # NOTE: doing this here instead of in Predictor$predict() to keep
      #       Predictor$predict() as fast as possible
      if (length(params) > 0L) {
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_raw_score"
          , params = params
          , alternative_kwarg_value = rawscore
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_leaf_index"
          , params = params
          , alternative_kwarg_value = predleaf
        )
        params <- lgb.check.wrapper_param(
          main_param_name = "predict_contrib"
          , params = params
          , alternative_kwarg_value = predcontrib
        )
        rawscore <- params[["predict_raw_score"]]
        predleaf <- params[["predict_leaf_index"]]
        predcontrib <- params[["predict_contrib"]]
      }

      # Predict on new data
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
        )
      )

    },

    # Transform into predictor
    to_predictor = function() {
      # Like every other handle-using method, make sure the native handle is valid
      self$restore_handle()
      return(Predictor$new(modelfile = private$handle))
    },

    # Used for serialization: raw bytes of the saved model (NULL until save_raw())
    raw = NULL,

    # Store serialized raw bytes in model object (only if not already stored)
    save_raw = function() {
      if (is.null(self$raw)) {
        self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      }
      return(invisible(NULL))

    },

    # Discard the cached serialized bytes
    drop_raw = function() {
      self$raw <- NULL
      return(invisible(NULL))
    },

    # TRUE when the native handle is unusable, per lgb.is.null.handle()
    check_null_handle = function() {
      return(lgb.is.null.handle(private$handle))
    },

    # Rebuild the native handle from self$raw when the current handle is null;
    # raises the native "null booster handle" error if nothing is cached.
    restore_handle = function() {
      if (self$check_null_handle()) {
        if (is.null(self$raw)) {
          .Call(LGBM_NullBoosterHandleError_R)
        }
        handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
        # Tag the handle exactly as initialize() does, so code that receives
        # it (e.g. Predictor$new() via predict()/to_predictor()) sees the same class
        class(handle) <- "lgb.Booster.handle"
        private$handle <- handle
      }
      return(invisible(NULL))
    },

    get_handle = function() {
      return(private$handle)
    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    # Per-dataset prediction buffers, keyed by dataset name
    predict_buffer = list(),
    # Per-dataset flag: has predict_buffer been filled for the current iteration?
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,
    train_set_version = 0L,

    # Return current-iteration predictions for dataset `idx`
    # (1 = training set, i + 1 = i-th validation set), filling a cached buffer.
    inner_predict = function(idx) {

      # Check for unknown dataset (over the maximum provided range) before
      # indexing name_valid_sets, so an out-of-range idx gets this message
      # rather than a "subscript out of bounds" error.
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Store data name
      data_name <- private$name_train_set

      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the prediction buffer on first use
      if (is.null(private$predict_buffer[[data_name]])) {

        # Freshly allocated: the C routine overwrites this vector in place
        npred <- integer(1L)
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refill the buffer only if this iteration has not been predicted yet
      if (!private$is_predicted_cur_iter[[idx]]) {

        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Get evaluation metric names (cached after the first non-empty lookup)
    get_eval_info = function() {

      if (is.null(private$eval_names)) {
        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {

          # Parse and store privately names
          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },

    # Evaluate dataset `data_idx` on built-in metrics plus optional feval.
    # Returns a list of list(data_name, name, value, higher_better).
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Create evaluation values (filled in place by the C routine)
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {

          # Store evaluation and append to return
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Check if there is a custom evaluation metric
      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set

        # Validation sets live at data_idx - 1 in valid_sets
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        # Perform function evaluation
        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a
            list with attribute (name, value, higher_better)")
        }

        # Append names and evaluation
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }

  )
)

738
739
740
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param newdata a \code{matrix} object, a \code{dgCMatrix} object or
#'                a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param start_iteration integer or \code{NULL} (default). Index of the first iteration
#'                        used for prediction. If \code{NULL} or \code{<= 0}, prediction
#'                        starts from the first iteration.
#' @param num_iteration integer or \code{NULL} (default). Limit on the number of iterations
#'                      used for prediction. If \code{NULL}, the best iteration is used when it
#'                      exists and \code{start_iteration} is \code{NULL} or \code{<= 0};
#'                      otherwise all iterations from \code{start_iteration} onward are used.
#'                      If \code{<= 0}, all iterations from \code{start_iteration} are used (no limit).
#' @param rawscore whether the prediction should be returned in the form of the original untransformed
#'                 sum of predictions from the boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf indices instead.
#' @param predcontrib whether to return per-feature contributions for each record.
#' @param header only used for prediction on a text file. \code{TRUE} if the text file has a header row.
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values. Where these conflict with the values of keyword arguments to this function,
#'               the values in \code{params} take precedence.
#' @param ... ignored (a warning lists anything passed here; \code{reshape} is an error)
#' @return For regression or binary classification, a vector of length \code{nrow(newdata)}.
#'         For multiclass classification, a matrix of dimensions \code{(nrow(newdata), num_class)}.
#'
#'         When passing \code{predleaf=TRUE} or \code{predcontrib=TRUE}, the output will always be
#'         returned as a matrix.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'   model,
#'   test$data,
#'   params = list(
#'     predict_disable_shape_check = TRUE
#'   )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                newdata,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                params = list(),
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Nothing captured by `...` is forwarded. The retired `reshape` argument is
  # rejected outright; anything else only produces a warning naming it.
  dots <- list(...)
  if (length(dots) > 0L) {
    dot_names <- names(dots)
    if ("reshape" %in% dot_names) {
      stop("'reshape' argument is no longer supported.")
    }
    ignored_msg <- paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , paste(dot_names, collapse = ", ")
      , ". These are ignored. Use argument 'params' instead."
    )
    warning(ignored_msg)
  }

  # All real work happens in the Booster's own predict() method
  return(
    object$predict(
      data = newdata
      , start_iteration = start_iteration
      , num_iteration = num_iteration
      , rawscore = rawscore
      , predleaf = predleaf
      , predcontrib = predcontrib
      , header = header
      , params = params
    )
  )
}

845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  handle <- x$.__enclos_env__$private$handle
  handle_is_null <- lgb.is.null.handle(handle)

  if (!handle_is_null) {
    ntrees <- x$current_iter()
    if (ntrees == 1L) {
      cat("LightGBM Model (1 tree)\n")
    } else {
      cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
    }
  } else {
    cat("LightGBM Model\n")
  }

  if (!handle_is_null) {
    # Boosters created from a model file or model string (e.g. via lgb.load())
    # are constructed with params = list(), so params$objective is NULL there.
    # Comparing NULL with "none" inside if() would fail with
    # "argument is of length zero", so fall back to a placeholder label.
    obj <- x$params$objective
    if (!is.character(obj) || length(obj) != 1L || is.na(obj)) {
      obj <- "unknown"
    } else if (obj == "none") {
      # objective "none" is what update() sets when a custom fobj is used
      obj <- "custom"
    }

    num_class <- x$.__enclos_env__$private$num_class
    if (num_class == 1L) {
      cat(sprintf("Objective: %s\n", obj))
    } else {
      cat(sprintf("Objective: %s (%d classes)\n"
          , obj
          , num_class))
    }
  } else {
    cat("(Booster handle is invalid)\n")
  }

  if (!handle_is_null) {
    ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
    cat(sprintf("Fitted to dataset with %d columns\n", ncols))
  }
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # summary() and print() show identical output, so simply delegate;
  # the print method's (invisible) return value is passed through unchanged.
  print(x = object)
}

905
906
#' @name lgb.load
#' @title Load LightGBM model
907
908
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
Guolin Ke's avatar
Guolin Ke committed
909
#' @param filename path of model file
910
#' @param model_str a string containing the model (as a `character` or `raw` vector)
911
#'
912
#' @return lgb.Booster
913
#'
Guolin Ke's avatar
Guolin Ke committed
914
#' @examples
915
#' \donttest{
916
917
918
919
920
921
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
922
923
924
925
926
927
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
928
#' valids <- list(test = dtest)
929
930
931
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
932
#'   , nrounds = 5L
933
#'   , valids = valids
934
#'   , early_stopping_rounds = 3L
935
#' )
936
937
938
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
939
940
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
941
#' }
Guolin Ke's avatar
Guolin Ke committed
942
#' @export
943
lgb.load <- function(filename = NULL, model_str = NULL) {
944

945
946
  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)
947

948
949
950
951
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
952
    filename <- path.expand(filename)
953
954
955
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
956
957
    return(invisible(Booster$new(modelfile = filename)))
  }
958

959
  if (model_str_provided) {
960
961
    if (!is.raw(model_str) && !is.character(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
962
    }
963
964
    return(invisible(Booster$new(model_str = model_str)))
  }
965

966
  stop("lgb.load: either filename or model_str must be given")
Guolin Ke's avatar
Guolin Ke committed
967
968
}

#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename (a single, non-missing string)
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # Exactly one non-missing path is accepted; path.expand() would pass an
  # NA through unchanged to save_model().
  if (!(is.character(filename) && length(filename) == 1L && !is.na(filename))) {
    stop("lgb.save: filename should be a string")
  }
  filename <- path.expand(filename)

  # Store booster; the value returned by save_model() is passed back invisibly
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )
}

#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dump, NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # Error message must name this function (it previously said "lgb.save")
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return the model dump at the requested iteration
  return(booster$dump_model(num_iteration = num_iteration))
}

#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Both names must be single strings: longer vectors would make the %in%
  # checks below produce a logical vector, which if() rejects.
  if (!(is.character(data_name) && length(data_name) == 1L
        && is.character(eval_name) && length(eval_name) == 1L)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check if evaluation result is existing
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Select either the recorded metric values or their recorded errors
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Convert the requested iteration numbers into positions in `result` by
  # offsetting them with record_evals$start_iter
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}