lgb.Booster.R 32.4 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
#' @importFrom R6 R6Class
2
#' @importFrom utils modifyList
James Lamb's avatar
James Lamb committed
3
Booster <- R6::R6Class(
4
  classname = "lgb.Booster",
5
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
6
  public = list(
7

8
    best_iter = -1L,
9
    best_score = NA_real_,
10
    params = list(),
Guolin Ke's avatar
Guolin Ke committed
11
    record_evals = list(),
12

13
14
    # Release the native Booster handle through LGBM_BoosterFree_R and clear
    # the stored reference. Returns NULL invisibly.
    finalize = function() {
      .Call(LGBM_BoosterFree_R, private$handle)
      private$handle <- NULL
      return(invisible(NULL))
    },
22

23
24
    # Create a Booster from exactly one source, checked in this order:
    #   1. train_set - an lgb.Dataset to train on
    #   2. modelfile - path to a saved model file
    #   3. model_str - a model held in a character or raw vector
    # When train_set is used, `params` is merged with the Dataset's parameters.
    # Stops with "cannot create Booster handle" if no native handle was obtained.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      # Errors raised inside try() are reported but do not propagate; a failure
      # leaves `handle` as NULL, which the check after this block turns into an error
      try({

        if (!is.null(train_set)) {

          if (!lgb.is.Dataset(train_set)) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }

          dataset_private <- train_set$.__enclos_env__$private
          train_set_handle <- dataset_private$get_handle()
          params <- utils::modifyList(params, train_set$get_params())
          params_str <- lgb.params2str(params = params)

          handle <- .Call(LGBM_BoosterCreate_R, train_set_handle, params_str)

          # Remember the training data and the state it was in
          private$train_set <- train_set
          private$train_set_version <- dataset_private$version
          private$num_dataset <- 1L
          private$init_predictor <- dataset_private$predictor

          # A Dataset that carries a predictor has that model merged into the new Booster
          if (!is.null(private$init_predictor)) {
            .Call(
              LGBM_BoosterMerge_R
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )
          }

          # One "already predicted" flag per registered dataset
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          modelfile <- path.expand(modelfile)
          handle <- .Call(LGBM_BoosterCreateFromModelfile_R, modelfile)

        } else if (!is.null(model_str)) {

          if (!is.raw(model_str) && !is.character(model_str)) {
            stop("lgb.Booster: Can only use a character/raw vector as model_str")
          }

          handle <- .Call(LGBM_BoosterLoadModelFromString_R, model_str)

        } else {

          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      # Reached with a NULL handle whenever anything inside try() failed
      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      # num_class is handed to the native call, which presumably overwrites it
      # in place (the return value is not used) - TODO confirm
      private$num_class <- 1L
      .Call(LGBM_BoosterGetNumClasses_R, private$handle, private$num_class)

      self$params <- params

      return(invisible(NULL))

    },
136

137
    # Store the name under which the training data is reported by eval_train().
    # Returns the Booster invisibly.
    set_train_data_name = function(name) {
      private$name_train_set <- name
      return(invisible(self))
    },
145

146
    # Register `data` (an lgb.Dataset) as validation data under `name`.
    # The Dataset must carry the same predictor as this Booster's init_predictor.
    # Returns the Booster invisibly.
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      data_private <- data$.__enclos_env__$private

      if (!identical(data_private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      .Call(LGBM_BoosterAddValidData_R, private$handle, data_private$get_handle())

      # Bookkeeping: the new dataset, its name, the dataset count and a fresh
      # "already predicted" flag
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },
179

180
    # Apply new parameters to the Booster.
    # When the stored parameters are a list, `params` is merged into them;
    # the result is passed to the native Booster and stored in self$params.
    # Returns the Booster invisibly.
    reset_parameter = function(params) {

      new_params <- params
      if (methods::is(self$params, "list")) {
        new_params <- utils::modifyList(self$params, params)
      }
      params_str <- lgb.params2str(params = new_params)

      self$restore_handle()

      .Call(LGBM_BoosterResetParameter_R, private$handle, params_str)
      self$params <- new_params

      return(invisible(self))

    },
200

201
    # Perform one boosting iteration.
    #   train_set: optional lgb.Dataset that replaces the current training data.
    #              When NULL, the stored training set is re-registered only if
    #              its version differs from the one recorded by this Booster.
    #   fobj:      optional custom objective function, called as
    #              fobj(predictions, train_set); it must return a list with
    #              elements `grad` and `hess` of equal length.
    # Returns the Booster invisibly.
    update = function(train_set = NULL, fobj = NULL) {

      if (is.null(train_set)) {

        # Without stored training data the version comparison below would fail
        # with an uninformative "argument is of length zero" error
        if (is.null(private$train_set)) {
          stop("lgb.Booster.update: cannot update a Booster without training data; pass train_set")
        }

        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }

      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # The predictor is read from the Dataset's private environment, the same
        # way initialize() and add_valid() read it
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # The built-in objective was switched off by an earlier custom-objective update
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # The first custom-objective update disables the built-in objective
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective should "
            , "return a list with attributes (hess, grad)"
          )
        }

        # Only one length is passed to native code, so both vectors must agree
        if (length(gpair$grad) != length(gpair$hess)) {
          stop(sprintf(
            "lgb.Booster.update: custom objective returned grad of length %d but hess of length %d"
            , length(gpair$grad)
            , length(gpair$hess)
          ))
        }

        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # The model changed, so no dataset has predictions for the current iteration
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },
279

280
    # Roll the model back by one iteration (LGBM_BoosterRollbackOneIter_R) and
    # mark every dataset as not yet predicted. Returns the Booster invisibly.
    rollback_one_iter = function() {

      self$restore_handle()

      .Call(LGBM_BoosterRollbackOneIter_R, private$handle)

      # Reset every "already predicted" flag to FALSE
      private$is_predicted_cur_iter <- lapply(
        X = private$is_predicted_cur_iter
        , FUN = function(flag) {
          FALSE
        }
      )

      return(invisible(self))

    },
298

299
    # Return the current iteration as reported by LGBM_BoosterGetCurrentIteration_R
    current_iter = function() {

      self$restore_handle()

      # Handed to the native call, which presumably overwrites it in place
      # (the .Call return value is not used) - TODO confirm
      n_iter <- 0L
      .Call(LGBM_BoosterGetCurrentIteration_R, private$handle, n_iter)

      return(n_iter)

    },
313

314
    # Return the value reported by LGBM_BoosterGetUpperBoundValue_R
    upper_bound = function() {

      self$restore_handle()

      # Handed to the native call, which presumably overwrites it in place
      # (the .Call return value is not used) - TODO confirm
      bound <- 0.0
      .Call(LGBM_BoosterGetUpperBoundValue_R, private$handle, bound)

      return(bound)

    },

    # Return the value reported by LGBM_BoosterGetLowerBoundValue_R
    lower_bound = function() {

      self$restore_handle()

      # Handed to the native call, which presumably overwrites it in place
      # (the .Call return value is not used) - TODO confirm
      bound <- 0.0
      .Call(LGBM_BoosterGetLowerBoundValue_R, private$handle, bound)

      return(bound)

    },

344
    # Evaluate `data` (an lgb.Dataset) and return the list produced by
    # inner_eval(). A Dataset that is neither the training set nor an already
    # registered validation set is first added as validation data under `name`.
    # `feval` is an optional custom evaluation function.
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Position of `data` among the known datasets:
      # 1 = training set, k + 1 = k-th validation set, 0 = not registered yet
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1L
            break
          }
        }
      }

      # Unknown data: register it, which makes it the last dataset
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      return(
        private$inner_eval(data_name = name, data_idx = data_idx, feval = feval)
      )

    },
396

397
    # Evaluate the training data (dataset index 1) under its stored name
    eval_train = function(feval = NULL) {
      return(
        private$inner_eval(
          data_name = private$name_train_set
          , data_idx = 1L
          , feval = feval
        )
      )
    },
401

402
    # Evaluate every registered validation set, in registration order, and
    # return all results in one flat list (empty when there are no validation sets)
    eval_valid = function(feval = NULL) {

      results <- list()

      for (i in seq_along(private$valid_sets)) {
        # validation set i is dataset i + 1; index 1 is the training data
        results <- c(
          results
          , private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(results)

    },
421

422
    # Write the model to `filename` (after path expansion).
    # num_iteration defaults to self$best_iter when NULL.
    # Returns the Booster invisibly.
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      n_iter <- num_iteration
      if (is.null(n_iter)) {
        n_iter <- self$best_iter
      }

      out_path <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(n_iter)
        , as.integer(feature_importance_type)
        , out_path
      )

      return(invisible(self))
    },
443

444
445
446
    # Serialize the model in memory.
    # num_iteration defaults to self$best_iter when NULL.
    # Returns a character string when as_char is TRUE, otherwise the value
    # returned by the native call unchanged.
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

      self$restore_handle()

      n_iter <- num_iteration
      if (is.null(n_iter)) {
        n_iter <- self$best_iter
      }

      serialized <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(n_iter)
        , as.integer(feature_importance_type)
      )

      if (as_char) {
        return(rawToChar(serialized))
      }

      return(serialized)

    },
466

467
    # Return the model dump produced by LGBM_BoosterDumpModel_R.
    # num_iteration defaults to self$best_iter when NULL.
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      n_iter <- num_iteration
      if (is.null(n_iter)) {
        n_iter <- self$best_iter
      }

      dumped <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(n_iter)
        , as.integer(feature_importance_type)
      )

      return(dumped)

    },
486

487
    # Predict on new data by wrapping the Booster handle in a Predictor.
    # num_iteration defaults to self$best_iter and start_iteration to 0L when
    # NULL. Arguments passed through `...` are deprecated: they trigger a
    # warning and are merged into `params`.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE,
                       params = list(),
                       ...) {

      self$restore_handle()

      additional_params <- list(...)
      if (length(additional_params) > 0L) {
        passed_names <- paste(names(additional_params), collapse = ", ")
        warning(paste0(
          "Booster$predict(): Found the following passed through '...': "
          , passed_names
          , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
          , "Add these to 'params' instead. See ?predict.lgb.Booster for documentation on how to call this function."
        ))
      }

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # Entries in '...' take precedence over entries of the same name in params
      params <- utils::modifyList(params, additional_params)
      predictor <- Predictor$new(modelfile = private$handle, params = params)

      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
          , reshape = reshape
        )
      )

    },
539

540
541
    # Build a Predictor around this Booster's native handle
    to_predictor = function() {
      predictor <- Predictor$new(modelfile = private$handle)
      return(predictor)
    },
544

545
546
    # Used for serialization
    raw = NULL,
547

548
549
550
551
552
553
    # Cache the serialized model in self$raw (as returned by
    # save_model_to_string() with as_char = FALSE). Does nothing when self$raw
    # is already set. Returns NULL invisibly.
    save_raw = function() {

      if (!is.null(self$raw)) {
        return(invisible(NULL))
      }

      self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      return(invisible(NULL))

    },
556

557
558
    # Discard the serialized model cached in self$raw. Returns NULL invisibly.
    drop_raw = function() {
      self$raw <- NULL
      return(invisible(NULL))
    },
561

562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
    # Report whether lgb.is.null.handle() considers the stored handle null
    check_null_handle = function() {
      handle_is_null <- lgb.is.null.handle(private$handle)
      return(handle_is_null)
    },

    # Recreate the native handle from the serialized model in self$raw when the
    # current handle is null; does nothing otherwise. Returns NULL invisibly.
    restore_handle = function() {
      if (self$check_null_handle()) {
        if (is.null(self$raw)) {
          # Nothing to restore from; this native routine presumably raises the
          # error for a null handle - TODO confirm
          .Call(LGBM_NullBoosterHandleError_R)
        }
        restored <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
        # Tag the handle with the same class initialize() gives a fresh handle,
        # so code that checks for "lgb.Booster.handle" treats both alike
        if (!is.null(restored)) {
          class(restored) <- "lgb.Booster.handle"
        }
        private$handle <- restored
      }
      return(invisible(NULL))
    },

    # Accessor for the stored native Booster handle
    get_handle = function() {
      return(private$handle)
    }
579

Guolin Ke's avatar
Guolin Ke committed
580
581
  ),
  private = list(
582
583
584
585
586
587
588
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
589
590
    num_class = 1L,
    num_dataset = 0L,
591
592
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
593
    higher_better_inner_eval = NULL,
594
    set_objective_to_none = FALSE,
595
    train_set_version = 0L,
596
597
    # Return the predictions for dataset `idx` (1 = training data,
    # k + 1 = k-th validation set). A buffer is kept per dataset name and is
    # refreshed from the native Booster only when the dataset's
    # "already predicted" flag is FALSE.
    inner_predict = function(idx) {

      # Validate the index before it is used to look up a validation-set name,
      # so an out-of-range idx yields this message instead of a subscript error
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # First request for this dataset: allocate its buffer
      if (is.null(private$predict_buffer[[data_name]])) {

        # npred is handed to the native call, which presumably overwrites it
        # in place (the .Call return value is not used) - TODO confirm
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer unless it already holds this iteration's predictions
      if (!private$is_predicted_cur_iter[[idx]]) {

        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE

      }

      return(private$predict_buffer[[data_name]])
    },
641

642
    # Return the Booster's metric names, fetching them from the native Booster
    # on first use. When names are found they are cached together with the
    # matching entries of .METRICS_HIGHER_BETTER().
    get_eval_info = function() {

      if (!is.null(private$eval_names)) {
        return(private$eval_names)
      }

      eval_names <- .Call(LGBM_BoosterGetEvalNames_R, private$handle)

      if (length(eval_names) > 0L) {

        private$eval_names <- eval_names

        # Names such as "ndcg@1" are the "ndcg" metric evaluated at a position,
        # so everything from "@" onwards is dropped before the lookup
        metric_names <- gsub("@.*", "", eval_names)
        private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

      }

      return(private$eval_names)

    },
668

Guolin Ke's avatar
Guolin Ke committed
669
    # Evaluate dataset `data_idx` (1 = training data, k + 1 = k-th validation set).
    # Returns a list with one entry per built-in metric, followed by one entry
    # for `feval` when it is supplied. Each entry is a list with elements
    # data_name, name, value and higher_better.
    # `feval` is called as feval(predictions, dataset) and must return a list
    # with elements name, value and higher_better.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # tmp_vals is handed to the native call, which presumably fills it in
        # place (the .Call return value is not used) - TODO confirm
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        # One result entry per metric, written into a preallocated list
        ret <- vector("list", length(private$eval_names))
        for (i in seq_along(private$eval_names)) {

          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret[[i]] <- res

        }

      }

      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # The Dataset handed to feval is the one that data_idx refers to
        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a "
            , "list with attribute (name, value, higher_better)"
          )
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }
738

Guolin Ke's avatar
Guolin Ke committed
739
740
741
  )
)

742
743
744
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
Guolin Ke's avatar
Guolin Ke committed
745
#' @param object Object of class \code{lgb.Booster}
746
747
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
#'             a character representing a path to a text file (CSV, TSV, or LibSVM)
748
749
750
751
752
753
754
755
#' @param start_iteration int or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration int or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
756
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
759
#' @param predleaf whether predict leaf index instead.
760
#' @param predcontrib return per-feature contributions for each record.
Guolin Ke's avatar
Guolin Ke committed
761
#' @param header only used for prediction for text file. True if text file has header
762
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
763
#'                prediction outputs per case.
764
765
766
767
768
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values.
#' @param ... Additional prediction parameters. NOTE: deprecated as of v3.3.0. Use \code{params} instead.
769
770
771
772
#' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}.
#'         For multiclass classification, either a \code{num_class * nrows(data)} vector or
#'         a \code{(nrows(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
773
#'
774
775
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
776
#'
Guolin Ke's avatar
Guolin Ke committed
777
#' @examples
778
#' \donttest{
779
780
781
782
783
784
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
785
786
787
788
789
790
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
791
#' valids <- list(test = dtest)
792
793
794
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
795
#'   , nrounds = 5L
796
797
#'   , valids = valids
#' )
798
#' preds <- predict(model, test$data)
799
800
#'
#' # pass other prediction parameters
801
#' preds <- predict(
802
803
804
805
806
807
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'    )
#' )
808
#' }
809
#' @importFrom utils modifyList
Guolin Ke's avatar
Guolin Ke committed
810
#' @export
James Lamb's avatar
James Lamb committed
811
812
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                params = list(),
                                ...) {

  # S3 entry point: validate the object, then delegate to its predict() method
  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Arguments passed through '...' are deprecated; they still take effect
  # (merged into params below) but trigger a warning
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    passed_names <- paste(names(additional_params), collapse = ", ")
    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , passed_names
      , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
      , "Add these to 'params' instead. See ?predict.lgb.Booster for documentation on how to call this function."
    ))
  }

  return(
    object$predict(
      data = data
      , start_iteration = start_iteration
      , num_iteration = num_iteration
      , rawscore = rawscore
      , predleaf = predleaf
      , predcontrib = predcontrib
      , header = header
      , reshape = reshape
      , params = utils::modifyList(params, additional_params)
    )
  )
}

852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # Prints a short description of the model: number of trees, objective
  # (with the class count when there is more than one class) and the number of
  # columns in the training data. A Booster whose handle is null gets a
  # two-line notice instead. Returns `x` invisibly.
  # nolint start
  handle <- x$.__enclos_env__$private$handle

  if (lgb.is.null.handle(handle)) {
    cat("LightGBM Model\n")
    cat("(Booster handle is invalid)\n")
    return(invisible(x))
  }

  ntrees <- x$current_iter()
  if (ntrees == 1L) {
    cat("LightGBM Model (1 tree)\n")
  } else {
    cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
  }

  # `[[` is used so that only an entry named exactly "objective" is read;
  # `$` would also match a longer name that starts with "objective"
  obj <- x$params[["objective"]]
  if (is.null(obj)) {
    # A Booster created from a model file or string with default params has
    # no objective recorded in params
    obj <- "(unknown)"
  } else if (isTRUE(obj == "none")) {
    obj <- "custom"
  }

  num_class <- x$.__enclos_env__$private$num_class
  if (num_class == 1L) {
    cat(sprintf("Objective: %s\n", obj))
  } else {
    cat(sprintf("Objective: %s (%d classes)\n", obj, num_class))
  }

  ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
  cat(sprintf("Fitted to dataset with %d columns\n", ncols))
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object.
#'              This simply delegates to \code{print}, so both methods produce the same output.
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # the print method does all the work; hand its (invisible) result back
  printed <- print(object)
  return(invisible(printed))
}

911
912
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
#' @param filename path of model file (a single string)
#' @param model_str a str containing the model (as a `character` or `raw` vector)
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # a file path takes precedence over a model string when both are given
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # file.exists() is vectorised, so a multi-element 'filename' would feed a
    # logical vector to `if` below; reject it here with a clear message
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a string")
    }
    filename <- path.expand(filename)
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.raw(model_str) && !is.character(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

975
976
977
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename (a single string)
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  if (!(is.character(filename) && length(filename) == 1L)) {
    stop("lgb.save: filename should be a string")
  }
  # expand a leading "~" before passing the path on
  filename <- path.expand(filename)

  # Store booster; the result of save_model() is returned invisibly
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )

}

1031
1032
1033
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump, NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    # name this function (not lgb.save) so the error points at the right call
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))

}

1076
1077
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for (a single string).
#' @param eval_name Name of the evaluation metric to return results for (a single string).
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # both names are used as scalar conditions (`%in%` inside `if`) and as
  # `[[` keys below, so anything other than exactly one value is rejected
  if (length(data_name) != 1L || length(eval_name) != 1L) {
    stop("lgb.get.eval.result: data_name and eval_name should each be a single string")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check if evaluation result is existing
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick either the recorded values or the recorded errors
  if (is_err) {
    result_key <- .EVAL_ERR_KEY()
  } else {
    result_key <- .EVAL_KEY()
  }
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Convert iteration numbers into positions within 'result': an iteration
  # equal to record_evals$start_iter maps to position 1
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}