lgb.Booster.R 30 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
#' @importFrom R6 R6Class
#' @importFrom utils modifyList
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Best iteration / score recorded for this booster. best_iter is the
    # default `num_iteration` used by save_model(), save_model_to_string(),
    # dump_model() and predict(); the default -1L means none was recorded.
    best_iter = -1L,
    best_score = NA_real_,
    # Parameters the booster was created with / last reset to
    params = list(),
    # NOTE(review): never written inside this class; presumably filled in by
    # the training loop (lgb.train / lgb.cv) -- confirm.
    record_evals = list(),

    # Finalizer: release the native booster handle and drop the reference
    finalize = function() {
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL
      return(invisible(NULL))
    },

    # Create a booster from exactly one source, checked in this order:
    #   train_set - an lgb.Dataset; its params are merged over `params`, and
    #               any predictor attached to it is merged into the booster
    #   modelfile - path (character) of a saved model file
    #   model_str - a model serialized as a character string
    # Validation and native errors propagate with their own messages; a
    # handle that still comes back null raises
    # "lgb.Booster: cannot create Booster handle".
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }
        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        params <- utils::modifyList(params, train_set$get_params())
        params_str <- lgb.params2str(params = params)

        handle <- .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
        )

        # Remember the training data and the version it was at, so update()
        # can detect later changes to it
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        # Continue from the predictor attached to the training Dataset
        if (!is.null(private$init_predictor)) {
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
          )
        }

        # One "already predicted this iteration" flag per dataset
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )

      } else if (!is.null(model_str)) {

        if (!is.character(model_str)) {
          stop("lgb.Booster: Can only use a string as model_str")
        }

        handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )

      } else {

        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      # Guard against the native layer handing back an empty handle
      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      # The native routine fills private$num_class in place
      private$num_class <- 1L
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
      )

      self$params <- params

      return(invisible(NULL))
    },

    # Set the name reported for the training data in evaluation results
    set_train_data_name = function(name) {
      private$name_train_set <- name
      return(invisible(self))
    },

    # Register an lgb.Dataset as a named validation set. It must share the
    # booster's init predictor.
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))
    },

    # Merge `params` (and, deprecated, anything in `...`) over the current
    # params and push the result to the native booster
    reset_parameter = function(params, ...) {

      additional_params <- list(...)
      if (length(additional_params) > 0L) {
        warning(paste0(
          "Booster$reset_parameter(): Found the following passed through '...': "
          , paste(names(additional_params), collapse = ", ")
          , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
          , "Add these to 'params' instead."
        ))
      }

      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }

      params <- utils::modifyList(params, additional_params)
      params_str <- lgb.params2str(params = params)

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))
    },

    # Run one boosting iteration.
    #   train_set: optional replacement lgb.Dataset (must share the booster's
    #              init predictor)
    #   fobj:      optional custom objective, called as
    #              fobj(preds, train_set); must return a list with `grad`
    #              and `hess`. Using it sets the objective to "none" for the
    #              rest of this booster's life.
    update = function(train_set = NULL, fobj = NULL) {

      # Re-sync with the stored training Dataset if its private version moved
      # since it was last handed to the booster. Boosters loaded from a model
      # file/string have no stored Dataset, so there is nothing to compare.
      if (is.null(train_set) && !is.null(private$train_set)) {
        current_version <- private$train_set$.__enclos_env__$private$version
        if (current_version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Read the Dataset's predictor from its private environment, exactly
        # as initialize() and add_valid() do
        new_predictor <- train_set$.__enclos_env__$private$predictor
        if (!identical(new_predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost one iteration with the built-in objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should
            return a list with attributes (hess, grad)")
        }

        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # The model changed, so every cached prediction is stale
      private$is_predicted_cur_iter[] <- FALSE

      return(invisible(self))
    },

    # Undo the most recent boosting iteration
    rollback_one_iter = function() {

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # The model changed, so every cached prediction is stale
      private$is_predicted_cur_iter[] <- FALSE

      return(invisible(self))
    },

    # Current iteration count (the native routine fills `cur_iter` in place)
    current_iter = function() {

      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)
    },

    # Upper bound of the model output (filled in place by the native routine)
    upper_bound = function() {

      upper_bound <- 0.0
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)
    },

    # Lower bound of the model output (filled in place by the native routine)
    lower_bound = function() {

      lower_bound <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)
    },

    # Evaluate metrics (and optional custom `feval`) on `data`. Data not yet
    # known to the booster is registered as a validation set named `name`.
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Dataset index: 1 = training data, i + 1 = i-th validation set
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1L
            break
          }
        }
      }

      # Unknown data: register it as a new validation set first
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )
    },

    # Evaluate on the training data
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluate on every registered validation set, concatenating the results
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)
    },

    # Write the model to `filename`; num_iteration defaults to best_iter
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },

    # Serialize the model to a string; num_iteration defaults to best_iter
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
          LGBM_BoosterSaveModelToString_R
          , private$handle
          , as.integer(num_iteration)
          , as.integer(feature_importance_type)
      )

      return(model_str)
    },

    # Dump the model (in memory); num_iteration defaults to best_iter
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)
    },

    # Predict on new data through a Predictor wrapping this booster's handle.
    # num_iteration defaults to best_iter, start_iteration to 0L.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE,
                       params = list(),
                       ...) {

      additional_params <- list(...)
      if (length(additional_params) > 0L) {
        warning(paste0(
          "Booster$predict(): Found the following passed through '...': "
          , paste(names(additional_params), collapse = ", ")
          , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
          , "Add these to 'params' instead. See ?predict.lgb.Booster for documentation on how to call this function."
        ))
      }

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      params <- utils::modifyList(params, additional_params)
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
          , reshape = reshape
        )
      )
    },

    # Wrap this booster's handle in a Predictor
    to_predictor = function() {
      return(Predictor$new(modelfile = private$handle))
    },

    # Holds the string-serialized model written by save()
    raw = NA,

    # Store the model, serialized with the default num_iteration (best_iter),
    # in `raw` so it survives in-memory persistence
    save = function() {
      self$raw <- self$save_model_to_string(NULL)
      return(invisible(NULL))
    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    # Cached predictions per dataset name, plus a per-dataset flag saying
    # whether the cache is current for this iteration
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,
    train_set_version = 0L,

    # Return (cached) predictions for dataset `idx`
    # (1 = training data, i + 1 = i-th validation set)
    inner_predict = function(idx) {

      # Validate first: indexing name_valid_sets with an out-of-range idx
      # would otherwise fail with an opaque "subscript out of bounds"
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the buffer once; the native routine reports its size in place
      if (is.null(private$predict_buffer[[data_name]])) {
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)
      }

      # Refresh the buffer in place only once per iteration
      if (!private$is_predicted_cur_iter[[idx]]) {
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Fetch (once) the metric names and whether higher is better for each
    get_eval_info = function() {

      if (is.null(private$eval_names)) {
        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {

          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }
      }

      return(private$eval_names)
    },

    # Evaluate built-in metrics and optional `feval` on dataset `data_idx`;
    # returns a list of list(data_name, name, value, higher_better)
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # The native routine fills tmp_vals in place
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))
        }

      }

      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) ||  is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a
            list with attribute (name, value, higher_better)")
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)
    }

  )
)

703
704
705
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
#'             a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param start_iteration int or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration int or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, \code{object$best_iter} is used. That field is -1 (no limit)
#'                      unless a best iteration was recorded, e.g. by early stopping.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf indices instead.
#' @param predcontrib whether to return per-feature contributions for each record instead.
#' @param header only used for prediction for text file. \code{TRUE} if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values.
#' @param ... Additional prediction parameters. NOTE: deprecated as of v3.3.0. Use \code{params} instead.
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
#'
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'    )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                params = list(),
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Deprecated path: fold anything passed through `...` into `params`
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , paste(names(additional_params), collapse = ", ")
      , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
      , "Add these to 'params' instead. See ?predict.lgb.Booster for documentation on how to call this function."
    ))
  }

  # The merged params are passed explicitly (not via `...`), so the booster
  # method does not emit a second deprecation warning
  return(
    object$predict(
      data = data
      , start_iteration = start_iteration
      , num_iteration = num_iteration
      , rawscore = rawscore
      , predleaf = predleaf
      , predcontrib = predcontrib
      , header = header
      , reshape = reshape
      , params = utils::modifyList(params, additional_params)
    )
  )
}

813
814
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
#' @param filename path of model file (a single character string)
#' @param model_str a single character string containing the model
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # A file path takes precedence over a model string
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # Anything but exactly one path would make file.exists() return zero or
    # several values, which `if` cannot use
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single string")
    }
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.character(model_str)) {
      stop("lgb.load: model_str should be character")
    }
    if (length(model_str) != 1L) {
      stop("lgb.load: model_str should be a single string")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

876
877
878
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename; must be a single, non-missing string
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return lgb.Booster (invisibly)
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # Require exactly one non-missing path; otherwise NA_character_ would be
  # forwarded to save_model() as if it were a real file name.
  if (!(is.character(filename) && length(filename) == 1L && !is.na(filename))) {
    stop("lgb.save: filename should be a string")
  }

  # Write the model and hand back whatever save_model() returns, invisibly
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )

}

#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump, NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # Error message must name this function, not lgb.save()
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the booster's own JSON dump at the requested iteration
  return(booster$dump_model(num_iteration = num_iteration))

}

#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for (a single string).
#' @param eval_name Name of the evaluation metric to return results for (a single string).
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#'              Iteration numbers are shifted by the booster's recorded \code{start_iter}
#'              before indexing into the stored results.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # Both names are scalar lookup keys; vectors would turn the `%in%` checks
  # below into multi-element `if` conditions (an error in R >= 4.2).
  if (length(data_name) != 1L || length(eval_name) != 1L) {
    stop("lgb.get.eval.result: data_name and eval_name should each be a single string")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check that the requested metric was recorded for this dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick either the recorded metric values or their recorded errors
  record <- booster$record_evals[[data_name]][[eval_name]]
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- record[[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Stored results begin at start_iter, so shift the requested iteration
  # numbers into positions within `result`.
  iters <- as.integer(iters)
  delta <- booster$record_evals[["start_iter"]] - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}