# R6 class wrapping a native LightGBM booster handle. Instances are built with
# Booster$new() from a training lgb.Dataset, a model file, or a model string.
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Best iteration / score; -1L and NA_real_ mean "not set". save_model(),
    # dump_model() and predict() fall back to best_iter when no iteration is given.
    best_iter = -1L,
    best_score = NA_real_,
    # Parameters the booster was created / last reset with
    params = list(),
    # Evaluation history (NOTE(review): populated outside this class)
    record_evals = list(),

    # Release the native booster. Safe to call more than once: the handle is
    # set to NULL after freeing, and a missing / null handle is skipped.
    finalize = function() {
      if (!is.null(private$handle) && !isTRUE(lgb.is.null.handle(x = private$handle))) {
        .Call(
          LGBM_BoosterFree_R
          , private$handle
        )
      }
      private$handle <- NULL
      return(invisible(NULL))
    },

    # Create a booster from exactly one source, checked in this order:
    #   train_set - an lgb.Dataset; `params` are merged with the dataset's params
    #   modelfile - path to a saved model file (character)
    #   model_str - a model serialized as a string (character)
    # Validation errors and errors raised by the native calls propagate to the
    # caller unchanged, so no half-initialized booster is ever returned.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      handle <- NULL

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }

        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        # Parameters stored on the dataset take precedence over `params`
        params <- modifyList(params, train_set$get_params())
        params_str <- lgb.params2str(params = params)

        handle <- .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
        )

        # Remember the training data for update() / eval() / inner_predict()
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        # Continue from the model already attached to the dataset, if any
        if (!is.null(private$init_predictor)) {
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
          )
        }

        # Slot 1 of the prediction-cache flags belongs to the training data
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )

      } else if (!is.null(model_str)) {

        if (!is.character(model_str)) {
          stop("lgb.Booster: Can only use a string as model_str")
        }

        handle <- .Call(
          LGBM_BoosterLoadModelFromString_R
          , model_str
        )

      } else {

        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      # The native routine writes its result into this vector in place, so hand
      # it a freshly allocated buffer rather than an object created from a literal
      num_class <- integer(1L)
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , num_class
      )
      private$num_class <- num_class

      self$params <- params

      return(invisible(NULL))
    },

    # Set the name used for the training data in evaluation output
    set_train_data_name = function(name) {
      private$name_train_set <- name
      return(invisible(self))
    },

    # Register an lgb.Dataset as a validation set under `name`. The dataset must
    # share the booster's initial predictor.
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))
    },

    # Merge `params` and any named arguments in `...` (which win) into the
    # stored parameters and push the result to the native booster
    reset_parameter = function(params, ...) {

      if (methods::is(self$params, "list")) {
        params <- modifyList(self$params, params)
      }
      params <- modifyList(params, list(...))
      params_str <- lgb.params2str(params = params)

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))
    },

    # Run one boosting iteration.
    #   train_set - optional lgb.Dataset to train on from now on; it must share
    #               the booster's initial predictor
    #   fobj      - optional custom objective, called as fobj(preds, train_set);
    #               must return a list with equal-length `grad` and `hess`
    update = function(train_set = NULL, fobj = NULL) {

      # Re-attach the stored training set if it was modified since attachment
      # (boosters loaded from a model have no training set to check)
      if (is.null(train_set) && !is.null(private$train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # The predictor lives in the dataset's private environment, the same
        # place add_valid() and initialize() read it from
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

        # The cached training-prediction buffer was sized for the previous data;
        # drop it so inner_predict() re-queries the prediction length
        private$predict_buffer[[private$name_train_set]] <- NULL

      }

      if (is.null(fobj)) {

        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # A custom objective replaces the built-in one for the rest of training
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective should "
            , "return a list with attributes (hess, grad)"
          )
        }

        # Only length(grad) is handed to the native routine, so hess must match
        if (length(gpair$grad) != length(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective should "
            , "return grad and hess of the same length"
          )
        }

        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # The model changed, so every cached prediction is stale
      private$is_predicted_cur_iter[] <- FALSE

      return(invisible(self))
    },

    # Undo the most recent boosting iteration
    rollback_one_iter = function() {

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # The model changed, so every cached prediction is stale
      private$is_predicted_cur_iter[] <- FALSE

      return(invisible(self))
    },

    # Return the current iteration (integer)
    current_iter = function() {
      # Filled in place by the native routine; allocate a fresh buffer
      cur_iter <- integer(1L)
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)
    },

    # Return the model's upper bound value (double)
    upper_bound = function() {
      # Filled in place by the native routine; allocate a fresh buffer
      upper_bound <- numeric(1L)
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)
    },

    # Return the model's lower bound value (double)
    lower_bound = function() {
      # Filled in place by the native routine; allocate a fresh buffer
      lower_bound <- numeric(1L)
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)
    },

    # Evaluate `data` under `name`. A dataset the booster does not know yet is
    # first registered as a validation set via add_valid().
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Index 1 is the training data; validation set i is index i + 1
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1L
            break
          }
        }
      }

      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )
    },

    # Evaluate the training data
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluate every validation set; returns one flat list of results
    eval_valid = function(feval = NULL) {

      if (length(private$valid_sets) <= 0L) {
        return(list())
      }

      per_set <- lapply(
        seq_along(private$valid_sets)
        , function(i) {
          private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        }
      )

      return(do.call(c, per_set))
    },

    # Save the model to `filename`. num_iteration = NULL uses best_iter.
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },

    # Serialize the model to a string. num_iteration = NULL uses best_iter.
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)
    },

    # Dump the model (as returned by LGBM_BoosterDumpModel_R).
    # num_iteration = NULL uses best_iter.
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)
    },

    # Predict on new data. Extra named arguments in `...` are passed as
    # parameters to the Predictor built from this booster.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE,
                       ...) {

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # As documented for predict.lgb.Booster: default to the best iteration
      # only when predicting from the first iteration; otherwise use all
      # iterations from start_iteration (a value <= 0 means "no limit")
      if (is.null(num_iteration)) {
        if (start_iteration <= 0L) {
          num_iteration <- self$best_iter
        } else {
          num_iteration <- -1L
        }
      }

      predictor <- Predictor$new(
        modelfile = private$handle
        , params = list(...)
      )

      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
          , reshape = reshape
        )
      )
    },

    # Wrap this booster in a Predictor
    to_predictor = function() {
      return(Predictor$new(modelfile = private$handle))
    },

    # Serialized model string, filled by save()
    raw = NA,

    # Store the model string (best iteration) in `raw` for in-memory saving
    save = function() {
      self$raw <- self$save_model_to_string(NULL)
      return(invisible(NULL))
    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    # Prediction buffers keyed by data name; filled in place by native code
    predict_buffer = list(),
    # One flag per dataset index: is predict_buffer current for this iteration?
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,
    train_set_version = 0L,

    # Return the current predictions for dataset `idx` (1 = training data,
    # i + 1 = i-th validation set), refreshing the cached buffer when stale
    inner_predict = function(idx) {

      # Check the range first so an unknown index gets this error instead of a
      # subscript-out-of-bounds from name_valid_sets below
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      } else {
        data_name <- private$name_train_set
      }

      if (is.null(private$predict_buffer[[data_name]])) {
        # Filled in place by the native routine; allocate a fresh buffer
        npred <- integer(1L)
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)
      }

      if (!private$is_predicted_cur_iter[[idx]]) {
        # Writes the predictions into the cached buffer in place
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Fetch (once) and cache the evaluation metric names and whether a higher
    # value is better for each of them
    get_eval_info = function() {

      if (is.null(private$eval_names)) {

        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {
          private$eval_names <- eval_names
          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]
        }

      }

      return(private$eval_names)
    },

    # Evaluate dataset `data_idx` on the built-in metrics and, if given, the
    # custom `feval`. Returns a list of results, each a list with data_name,
    # name, value and higher_better.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Filled in place by the native routine
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        ret <- lapply(
          seq_along(private$eval_names)
          , function(i) {
            res <- list()
            res$data_name <- data_name
            res$name <- private$eval_names[i]
            res$value <- tmp_vals[i]
            res$higher_better <- private$higher_better_inner_eval[i]
            res
          }
        )

      }

      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        } else {
          data <- private$train_set
        }

        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a "
            , "list with attribute (name, value, higher_better)"
          )
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))

      }

      return(ret)
    }

  )
)

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
#'             a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param start_iteration int or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration int or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf index instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
#'
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' preds <- predict(model, test$data)
#' }
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the R6 method; every option (and anything in `...`) is forwarded
  return(
    object$predict(
      data = data
      , start_iteration = start_iteration
      , num_iteration = num_iteration
      , rawscore = rawscore
      , predleaf = predleaf
      , predcontrib = predcontrib
      , header = header
      , reshape = reshape
      , ...
    )
  )
}

#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
#' @param filename path of model file (a single string)
#' @param model_str a string containing the model
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # A file path wins when both sources are given
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # A single path keeps the file.exists() check below a single TRUE/FALSE
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single string")
    }
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.character(model_str)) {
      stop("lgb.load: model_str should be character")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename (a single string)
#' @param num_iteration number of iterations to save. If NULL (the default), the booster's
#'                      best iteration (\code{best_iter}) is used. Any other value is passed
#'                      unchanged to LightGBM's C API, where values <= 0 mean all iterations.
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  if (!(is.character(filename) && length(filename) == 1L)) {
    stop("lgb.save: filename should be a string")
  }

  # save_model() resolves a NULL num_iteration to the booster's best iteration
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )
}

#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump. If NULL (the default), the booster's
#'                      best iteration (\code{best_iter}) is used. Any other value is passed
#'                      unchanged to LightGBM's C API, where values <= 0 mean all iterations.
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    # Error prefix names this function (it previously said "lgb.save:")
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # dump_model() resolves a NULL num_iteration to the booster's best iteration
  return(booster$dump_model(num_iteration = num_iteration))
}

#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
920
921
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
Guolin Ke's avatar
Guolin Ke committed
922
#' @param booster Object of class \code{lgb.Booster}
923
924
925
926
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
Guolin Ke's avatar
Guolin Ke committed
927
#' @param is_err TRUE will return evaluation error instead
928
#'
929
#' @return numeric vector of evaluation result
930
#'
931
#' @examples
932
#' \donttest{
933
#' # train a regression model
934
935
936
937
938
939
940
941
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
942
943
944
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
945
#'   , nrounds = 5L
946
#'   , valids = valids
947
948
#'   , min_data = 1L
#'   , learning_rate = 1.0
949
#' )
950
951
952
953
954
955
956
957
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
958
#' lgb.get.eval.result(model, "test", "l2")
959
#' }
Guolin Ke's avatar
Guolin Ke committed
960
#' @export
961
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {
962

963
  if (!lgb.is.Booster(x = booster)) {
964
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
Guolin Ke's avatar
Guolin Ke committed
965
  }
966

967
968
  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
Guolin Ke's avatar
Guolin Ke committed
969
  }
970

971
972
973
974
975
976
977
978
979
980
  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
Guolin Ke's avatar
Guolin Ke committed
981
  }
982

983
  # Check if evaluation result is existing
984
985
986
987
988
989
990
991
992
993
994
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
Guolin Ke's avatar
Guolin Ke committed
995
996
    stop("lgb.get.eval.result: wrong eval name")
  }
997

998
  result <- booster$record_evals[[data_name]][[eval_name]][[.EVAL_KEY()]]
999

1000
  # Check if error is requested
1001
  if (is_err) {
1002
    result <- booster$record_evals[[data_name]][[eval_name]][[.EVAL_ERR_KEY()]]
Guolin Ke's avatar
Guolin Ke committed
1003
  }
1004

1005
  if (is.null(iters)) {
Guolin Ke's avatar
Guolin Ke committed
1006
1007
    return(as.numeric(result))
  }
1008

1009
  # Parse iteration and booster delta
Guolin Ke's avatar
Guolin Ke committed
1010
  iters <- as.integer(iters)
1011
  delta <- booster$record_evals$start_iter - 1.0
Guolin Ke's avatar
Guolin Ke committed
1012
  iters <- iters - delta
1013

1014
  return(as.numeric(result[iters]))
Guolin Ke's avatar
Guolin Ke committed
1015
}