lgb.Booster.R 29.5 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
3
  classname = "lgb.Booster",
4
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
5
  public = list(
6

7
    best_iter = -1L,
8
    best_score = NA_real_,
9
    params = list(),
Guolin Ke's avatar
Guolin Ke committed
10
    record_evals = list(),
11

12
13
    # Release the native booster held by this object.
    # Passes the stored handle to LGBM_BoosterFree_R, then drops the reference.
    # Returns NULL invisibly.
    finalize = function() {
      .Call(LGBM_BoosterFree_R, private$handle)
      private$handle <- NULL
      invisible(NULL)
    },
21

22
23
    # Build a Booster from exactly one source, tried in this order:
    #   1. train_set - an lgb.Dataset (its parameters are merged over params)
    #   2. modelfile - a character path to a saved model
    #   3. model_str - a character string holding a saved model
    # params: list of parameters; extra named arguments in ... are appended.
    # Raises an error when no usable handle could be created.
    # Returns NULL invisibly.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      params <- append(params, list(...))
      handle <- NULL

      # try() reports a failure raised below without aborting; the handle
      # check after this block is what raises the final error
      try({

        if (!is.null(train_set)) {

          if (!lgb.is.Dataset(train_set)) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }

          dataset_private <- train_set$.__enclos_env__$private
          dataset_handle <- dataset_private$get_handle()

          # Parameters stored on the Dataset take precedence over params
          params <- modifyList(params, train_set$get_params())
          handle <- .Call(
            LGBM_BoosterCreate_R,
            dataset_handle,
            lgb.params2str(params = params)
          )

          # Remember the training data and the version it had at creation
          private$train_set <- train_set
          private$train_set_version <- dataset_private$version
          private$num_dataset <- 1L
          private$init_predictor <- dataset_private$predictor

          # When the Dataset carries a predictor, hand its handle to
          # LGBM_BoosterMerge_R together with the new booster handle
          if (!is.null(private$init_predictor)) {
            .Call(
              LGBM_BoosterMerge_R,
              handle,
              private$init_predictor$.__enclos_env__$private$handle
            )
          }

          # One "already predicted" flag per registered dataset
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }
          handle <- .Call(LGBM_BoosterCreateFromModelfile_R, modelfile)

        } else if (!is.null(model_str)) {

          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }
          handle <- .Call(LGBM_BoosterLoadModelFromString_R, model_str)

        } else {

          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      # Reached with a NULL handle whenever the block above failed
      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      # num_class is passed to the native routine after being reset to 1L;
      # presumably the routine overwrites it in place - confirm in the C wrapper
      private$num_class <- 1L
      .Call(LGBM_BoosterGetNumClasses_R, private$handle, private$num_class)

      self$params <- params

      invisible(NULL)

    },
135

136
    # Store the label used for the training data in evaluation results.
    # Returns the booster itself, invisibly.
    set_train_data_name = function(name) {
      private$name_train_set <- name
      invisible(self)
    },
144

145
    # Register a validation dataset with the booster.
    # data: an lgb.Dataset whose predictor is identical to the booster's
    #   init predictor.
    # name: character label stored alongside the dataset.
    # Returns the booster itself, invisibly.
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      same_predictor <- identical(
        data$.__enclos_env__$private$predictor,
        private$init_predictor
      )
      if (!same_predictor) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      .Call(
        LGBM_BoosterAddValidData_R,
        private$handle,
        data$.__enclos_env__$private$get_handle()
      )

      # Bookkeeping: dataset, its name, the dataset count, and a fresh
      # "not yet predicted" flag for the new dataset
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      invisible(self)

    },
182

183
    # Replace the booster's parameters.
    # params: list of parameters, merged over the currently stored ones when
    #   those are a list; named arguments in ... are merged on top of that.
    # Returns the booster itself, invisibly.
    reset_parameter = function(params, ...) {

      merged <- params
      if (methods::is(self$params, "list")) {
        merged <- modifyList(self$params, merged)
      }
      merged <- modifyList(merged, list(...))

      .Call(
        LGBM_BoosterResetParameter_R,
        private$handle,
        lgb.params2str(params = merged)
      )
      self$params <- merged

      invisible(self)

    },
203

204
    # Run one boosting iteration.
    # train_set: optional lgb.Dataset to switch the booster's training data
    #   to. When NULL, the stored training set is re-attached only if its
    #   version differs from the version this booster recorded.
    # fobj: optional custom objective, called as fobj(preds, train_set) and
    #   required to return a list with elements grad and hess.
    # Returns the booster itself, invisibly.
    update = function(train_set = NULL, fobj = NULL) {

      if (is.null(train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      # Check if training set is not null
      if (!is.null(train_set)) {

        # Check if training set is lgb.Dataset
        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Check if predictors are identical. The predictor is a private field
        # of the Dataset (this is how initialize() and add_valid() read it);
        # reading train_set$predictor would yield NULL and reject every
        # Dataset once the booster has an init predictor.
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        # Reset training data on booster
        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        # Store private train set
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      # Check if objective is empty
      if (is.null(fobj)) {
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost iteration from known objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        # Check if objective is function
        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # The first custom-objective update switches the objective to "none"
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        # Perform objective calculation
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        # Check for gradient and hessian as list
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective should "
            , "return a list with attributes (hess, grad)"
          )
        }

        # Return custom boosting gradient/hessian
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # Cached predictions are stale after an update: clear every flag
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },
288

289
    # Undo the most recent boosting iteration via LGBM_BoosterRollbackOneIter_R
    # and mark every dataset's cached predictions as stale.
    # Returns the booster itself, invisibly.
    rollback_one_iter = function() {

      .Call(LGBM_BoosterRollbackOneIter_R, private$handle)

      # is_predicted_cur_iter is a list with one flag per dataset
      private$is_predicted_cur_iter <- lapply(
        private$is_predicted_cur_iter,
        function(flag) FALSE
      )

      invisible(self)

    },
306

307
    # Query the current iteration through LGBM_BoosterGetCurrentIteration_R.
    # The local integer is handed to the native routine and then returned;
    # presumably the routine writes the result into it in place - confirm in
    # the C wrapper.
    current_iter = function() {
      iteration <- 0L
      .Call(LGBM_BoosterGetCurrentIteration_R, private$handle, iteration)
      return(iteration)
    },
319

320
    # Query the upper bound value through LGBM_BoosterGetUpperBoundValue_R.
    # The local double is handed to the native routine and then returned;
    # presumably the routine writes the result into it in place - confirm in
    # the C wrapper.
    upper_bound = function() {
      bound <- 0.0
      .Call(LGBM_BoosterGetUpperBoundValue_R, private$handle, bound)
      return(bound)
    },

    # Query the lower bound value through LGBM_BoosterGetLowerBoundValue_R.
    # The local double is handed to the native routine and then returned;
    # presumably the routine writes the result into it in place - confirm in
    # the C wrapper.
    lower_bound = function() {
      bound <- 0.0
      .Call(LGBM_BoosterGetLowerBoundValue_R, private$handle, bound)
      return(bound)
    },

346
    # Evaluate the booster on a dataset.
    # data: an lgb.Dataset. If it is neither the training set nor an already
    #   registered validation set, it is registered through add_valid().
    # name: label reported as data_name in the results (and used when the
    #   dataset has to be registered).
    # feval: optional custom evaluation function, forwarded to inner_eval().
    # Returns the list produced by inner_eval().
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Index 1 is the training data, index i + 1 the i-th validation set,
      # 0 means "not known to this booster yet"
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (pos in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[pos]])) {
            data_idx <- pos + 1L
            break
          }
        }
      }

      # Unknown dataset: register it; it becomes the last dataset
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      private$inner_eval(
        data_name = name,
        data_idx = data_idx,
        feval = feval
      )

    },
400

401
    # Evaluate the booster on its training data (dataset index 1).
    # feval: optional custom evaluation function.
    eval_train = function(feval = NULL) {
      private$inner_eval(
        data_name = private$name_train_set,
        data_idx = 1L,
        feval = feval
      )
    },
405

406
    # Evaluate the booster on every registered validation set.
    # feval: optional custom evaluation function.
    # Returns one flat list holding the inner_eval() results of all validation
    # sets, in registration order; an empty list when there are none.
    eval_valid = function(feval = NULL) {

      if (length(private$valid_sets) <= 0L) {
        return(list())
      }

      # Validation set i lives at dataset index i + 1
      per_set <- lapply(
        seq_along(private$valid_sets),
        function(i) private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
      )

      # Concatenate the per-dataset result lists into one list
      do.call(c, per_set)

    },
428

429
    # Write the model to a file through LGBM_BoosterSaveModel_R.
    # filename: target path.
    # num_iteration: iteration count passed to the native routine; defaults to
    #   self$best_iter when NULL.
    # feature_importance_type: coerced to integer and passed through.
    # Returns the booster itself, invisibly.
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      .Call(
        LGBM_BoosterSaveModel_R,
        private$handle,
        as.integer(num_iteration),
        as.integer(feature_importance_type),
        filename
      )

      invisible(self)
    },
448

449
    # Return the value produced by LGBM_BoosterSaveModelToString_R.
    # num_iteration: iteration count passed to the native routine; defaults to
    #   self$best_iter when NULL.
    # feature_importance_type: coerced to integer and passed through.
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      serialized <- .Call(
        LGBM_BoosterSaveModelToString_R,
        private$handle,
        as.integer(num_iteration),
        as.integer(feature_importance_type)
      )
      return(serialized)

    },
467

468
    # Return the value produced by LGBM_BoosterDumpModel_R.
    # num_iteration: iteration count passed to the native routine; defaults to
    #   self$best_iter when NULL.
    # feature_importance_type: coerced to integer and passed through.
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      dumped <- .Call(
        LGBM_BoosterDumpModel_R,
        private$handle,
        as.integer(num_iteration),
        as.integer(feature_importance_type)
      )
      return(dumped)

    },
486

487
    # Predict on new data by delegating to a Predictor built on this
    # booster's handle.
    # num_iteration defaults to self$best_iter and start_iteration to 0L when
    # NULL. Named arguments in ... are passed to the Predictor as params; all
    # other arguments are forwarded unchanged to Predictor$predict().
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE,
                       ...) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      extra_params <- list(...)
      predictor <- Predictor$new(modelfile = private$handle, params = extra_params)

      predictor$predict(
        data = data,
        start_iteration = start_iteration,
        num_iteration = num_iteration,
        rawscore = rawscore,
        predleaf = predleaf,
        predcontrib = predcontrib,
        header = header,
        reshape = reshape
      )

    },
527

528
529
    # Wrap this booster's handle in a new Predictor object and return it.
    to_predictor = function() {
      Predictor$new(modelfile = private$handle)
    },
532

533
    # Used for save
534
    raw = NA,
535

536
    # Refresh self$raw with the current output of save_model_to_string()
    # (num_iteration = NULL, so self$best_iter is used).
    # Returns NULL invisibly.
    save = function() {
      self$raw <- self$save_model_to_string(num_iteration = NULL)
      invisible(NULL)
    }
545

Guolin Ke's avatar
Guolin Ke committed
546
547
  ),
  private = list(
548
549
550
551
552
553
554
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
555
556
    num_class = 1L,
    num_dataset = 0L,
557
558
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
559
    higher_better_inner_eval = NULL,
560
    set_objective_to_none = FALSE,
561
    train_set_version = 0L,
562
563
    # Return the booster's predictions for one registered dataset.
    # idx: 1 for the training data, i + 1 for the i-th validation set.
    # Predictions are kept in private$predict_buffer under the dataset's name
    # and refreshed only when the dataset's is_predicted_cur_iter flag is FALSE.
    # Raises an error when idx exceeds the number of registered datasets.
    inner_predict = function(idx) {

      # Validate the index before it is used to look up a dataset name, so
      # that an out-of-range idx produces this message rather than a
      # "subscript out of bounds" error from the lookup below
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Store data name
      data_name <- private$name_train_set

      # Validation sets start at index 2
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the buffer the first time this dataset is predicted
      if (is.null(private$predict_buffer[[data_name]])) {

        # npred is handed to the native routine and then used as the buffer
        # length; presumably it is filled in place - confirm in the C wrapper
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Check if current iteration was already predicted
      if (!private$is_predicted_cur_iter[[idx]]) {

        # Use buffer
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },
608

609
    # Return the booster's evaluation metric names.
    # On the first call that yields a non-empty result, the names are stored
    # in private$eval_names and private$higher_better_inner_eval is filled
    # from .METRICS_HIGHER_BETTER().
    get_eval_info = function() {

      # Reuse names stored by an earlier call
      if (!is.null(private$eval_names)) {
        return(private$eval_names)
      }

      fetched_names <- .Call(LGBM_BoosterGetEvalNames_R, private$handle)

      if (length(fetched_names) > 0L) {

        private$eval_names <- fetched_names

        # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
        # ndcg metric evaluated at the first "query result" in learning-to-rank
        base_names <- gsub("@.*", "", fetched_names)
        private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[base_names]

      }

      return(private$eval_names)

    },
637

638
    # Evaluate one registered dataset.
    # data_name: label stored as data_name in every result entry.
    # data_idx: 1 for the training data, i + 1 for the i-th validation set.
    # feval: optional custom evaluation function, called as
    #   feval(preds, dataset); it must return a list with elements name,
    #   value and higher_better.
    # Returns a list with one entry per built-in metric (data_name, name,
    # value, higher_better), followed by the feval entry when feval is given.
    # Raises an error when data_idx exceeds the number of registered datasets,
    # when feval is not a function, or when feval's result is malformed.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Make sure metric names are loaded
      private$get_eval_info()

      # Prepare return
      ret <- list()

      # Built-in metrics
      if (length(private$eval_names) > 0L) {

        # tmp_vals is handed to the native routine and read afterwards;
        # presumably it is filled in place - confirm in the C wrapper
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        # One result entry per metric, built in a single pass instead of
        # growing the list inside a loop
        ret <- lapply(
          X = seq_along(private$eval_names)
          , FUN = function(i) {
            res <- list()
            res$data_name <- data_name
            res$name <- private$eval_names[i]
            res$value <- tmp_vals[i]
            res$higher_better <- private$higher_better_inner_eval[i]
            return(res)
          }
        )

      }

      # Custom metric
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset that matches data_idx
        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        # Perform function evaluation
        res <- feval(private$inner_predict(data_idx), data)

        # The message is assembled from two pieces so that it contains no
        # embedded line break or indentation
        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a "
            , "list with attribute (name, value, higher_better)"
          )
        }

        # Append names and evaluation
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }
712

Guolin Ke's avatar
Guolin Ke committed
713
714
715
  )
)

716
717
718
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
Guolin Ke's avatar
Guolin Ke committed
719
720
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
721
722
723
724
725
726
727
728
#' @param start_iteration integer or NULL, optional (default = NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration integer or NULL, optional (default = NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
729
#' @param rawscore whether the prediction should be returned in the form of original untransformed
730
731
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
732
#' @param predleaf whether predict leaf index instead.
733
#' @param predcontrib return per-feature contributions for each record.
Guolin Ke's avatar
Guolin Ke committed
734
#' @param header only used for prediction for text file. True if text file has header
735
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
736
#'                prediction outputs per case.
James Lamb's avatar
James Lamb committed
737
738
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
739
740
741
742
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
743
#'
744
745
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
746
#'
Guolin Ke's avatar
Guolin Ke committed
747
#' @examples
748
#' \donttest{
749
750
751
752
753
754
755
756
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
757
758
759
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
760
#'   , nrounds = 5L
761
#'   , valids = valids
762
763
#'   , min_data = 1L
#'   , learning_rate = 1.0
764
#' )
765
#' preds <- predict(model, test$data)
766
#' }
Guolin Ke's avatar
Guolin Ke committed
767
#' @export
James Lamb's avatar
James Lamb committed
768
769
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  # S3 entry point: validate the object, then delegate to its predict() method
  is_booster <- lgb.is.Booster(x = object)
  if (!is_booster) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Every argument, including ..., is forwarded unchanged
  object$predict(
    data = data,
    start_iteration = start_iteration,
    num_iteration = num_iteration,
    rawscore = rawscore,
    predleaf = predleaf,
    predcontrib = predcontrib,
    header = header,
    reshape = reshape,
    ...
  )
}

799
800
#' @name lgb.load
#' @title Load LightGBM model
801
802
#' @description Load LightGBM takes in either a file path or model string.
#'              If both are provided, Load will default to loading from file
Guolin Ke's avatar
Guolin Ke committed
803
#' @param filename path of model file
804
#' @param model_str a str containing the model
805
#'
806
#' @return lgb.Booster
807
#'
Guolin Ke's avatar
Guolin Ke committed
808
#' @examples
809
#' \donttest{
810
811
812
813
814
815
816
817
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
818
819
820
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
821
#'   , nrounds = 5L
822
#'   , valids = valids
823
824
#'   , min_data = 1L
#'   , learning_rate = 1.0
825
#'   , early_stopping_rounds = 3L
826
#' )
827
828
829
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
830
831
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
832
#' }
Guolin Ke's avatar
Guolin Ke committed
833
#' @export
834
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # A file path takes precedence when both arguments are supplied
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # file.exists() is vectorised, so a longer vector would otherwise reach
    # the if() below with a condition of length > 1
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single string")
    }
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.character(model_str)) {
      stop("lgb.load: model_str should be character")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  # Neither source was supplied
  stop("lgb.load: either filename or model_str must be given")
}

859
860
861
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
Guolin Ke's avatar
Guolin Ke committed
862
863
864
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save; NULL or <= 0 means use the best iteration
865
#'
866
#' @return lgb.Booster
867
#'
Guolin Ke's avatar
Guolin Ke committed
868
#' @examples
869
#' \donttest{
870
871
872
873
874
875
876
877
878
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
879
880
881
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
882
#'   , nrounds = 10L
883
#'   , valids = valids
884
885
886
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
887
#' )
888
#' lgb.save(model, tempfile(fileext = ".txt"))
889
#' }
Guolin Ke's avatar
Guolin Ke committed
890
#' @export
891
lgb.save <- function(booster, filename, num_iteration = NULL) {
892

893
  if (!lgb.is.Booster(x = booster)) {
894
895
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }
896

897
898
  if (!(is.character(filename) && length(filename) == 1L)) {
    stop("lgb.save: filename should be a string")
899
  }
900

901
  # Store booster
902
903
904
905
906
907
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )
908

Guolin Ke's avatar
Guolin Ke committed
909
910
}

911
912
913
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration Number of iterations to use. NULL or <= 0 means use the best iteration.
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # Only lgb.Booster objects can be dumped. The message is prefixed with this
  # function's own name so the error points at the right call.
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))

}

953
954
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err If TRUE, return the recorded evaluation error instead of the evaluation value.
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Check if booster is booster
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Check if data and evaluation name are characters or not
  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check that the requested metric was recorded for this dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick the recorded values or the recorded errors, depending on is_err
  if (is_err) {
    result_key <- .EVAL_ERR_KEY()
  } else {
    result_key <- .EVAL_KEY()
  }
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  # No iterations requested: return everything that was recorded
  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Shift the requested iterations by (start_iter - 1) so they index into
  # result. [[ ]] is used instead of $ to avoid partial name matching on the list.
  iters <- as.integer(iters)
  delta <- booster$record_evals[["start_iter"]] - 1.0
  iters <- iters - delta

  # Return requested result
  return(as.numeric(result[iters]))
}