lgb.Booster.R 28.3 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
3
  classname = "lgb.Booster",
4
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
5
  public = list(
6

7
    best_iter = -1L,
8
    best_score = NA,
Guolin Ke's avatar
Guolin Ke committed
9
    record_evals = list(),
10

11
12
    # Finalize will free up the handles
    finalize = function() {

      # Nothing to release when the booster holds no live native handle
      if (lgb.is.null.handle(private$handle)) {
        return(invisible(NULL))
      }

      # Hand the handle to the native free routine, then drop our reference
      # so that the free call is not repeated on the same handle
      lgb.call(
        "LGBM_BoosterFree_R"
        , ret = NULL
        , private$handle
      )
      private$handle <- NULL

      return(invisible(NULL))

    },
24

25
26
    # Initialize will create a starter booster
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      # Build a booster from exactly one source, tried in this order:
      # a training lgb.Dataset, a model file path, or a model string.
      # Arguments passed through `...` are appended to `params`.
      params <- append(params, list(...))

      # Placeholder that lgb.call() is asked to fill with the booster handle
      handle <- 0.0

      # try() reports any error raised inside and lets execution continue, so a
      # failure in here surfaces below as "cannot create Booster handle"
      try({

        # Refuse to continue when no source at all was supplied
        if (is.null(train_set) && is.null(modelfile) && is.null(model_str)) {
          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )
        }

        if (!is.null(train_set)) {

          # Source 1: a training dataset
          if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }

          # Parameters stored on the dataset override the ones passed in
          dataset_handle <- train_set$.__enclos_env__$private$get_handle()
          params <- modifyList(params, train_set$get_params())
          param_string <- lgb.params2str(params)

          handle <- lgb.call(
            "LGBM_BoosterCreate_R"
            , ret = handle
            , dataset_handle
            , param_string
          )

          # Remember the dataset, its version and its initial predictor
          private$train_set <- train_set
          private$train_set_version <- train_set$.__enclos_env__$private$version
          private$num_dataset <- 1L
          private$init_predictor <- train_set$.__enclos_env__$private$predictor

          # When the dataset carries a predictor, merge it into the new booster
          if (!is.null(private$init_predictor)) {
            lgb.call(
              "LGBM_BoosterMerge_R"
              , ret = NULL
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )
          }

          # One "already predicted this iteration" flag per registered dataset
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          # Source 2: a model file on disk
          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          handle <- lgb.call(
            "LGBM_BoosterCreateFromModelfile_R"
            , ret = handle
            , lgb.c_str(modelfile)
          )

        } else {

          # Source 3: a model held in a string
          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          handle <- lgb.call(
            "LGBM_BoosterLoadModelFromString_R"
            , ret = handle
            , lgb.c_str(model_str)
          )

        }

      })

      # Anything that went wrong above leaves the placeholder untouched
      if (lgb.is.null.handle(handle)) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      # Tag the handle, store it, then ask the native side for the class count
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      private$num_class <- 1L
      private$num_class <- lgb.call(
        "LGBM_BoosterGetNumClasses_R"
        , ret = private$num_class
        , private$handle
      )

    },
138

139
    # Set training data name
Guolin Ke's avatar
Guolin Ke committed
140
    set_train_data_name = function(name) {

      # Record the label used for the training data (read by eval_train() and
      # inner_predict()); returns the booster invisibly so calls can be chained
      private$name_train_set <- name
      invisible(self)

    },
147

148
    # Add validation data
Guolin Ke's avatar
Guolin Ke committed
149
    add_valid = function(data, name) {

      # Register `data` (an lgb.Dataset) as a validation set labelled `name`.
      # Returns the booster invisibly.

      # Only lgb.Dataset objects can be attached
      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # The validation data must share the predictor the booster started from
      valid_predictor <- data$.__enclos_env__$private$predictor
      if (!identical(valid_predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      # The label has to be character
      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Attach the dataset on the native side
      valid_handle <- data$.__enclos_env__$private$get_handle()
      lgb.call(
        "LGBM_BoosterAddValidData_R"
        , ret = NULL
        , private$handle
        , valid_handle
      )

      # Mirror the registration in the R-side bookkeeping
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      invisible(self)

    },
187

188
    # Reset parameters of booster
Guolin Ke's avatar
Guolin Ke committed
189
    reset_parameter = function(params, ...) {

      # Combine `params` with anything passed through `...`, serialise the
      # result with lgb.params2str() and push it to the native booster.
      # Returns the booster invisibly.
      all_params <- append(params, list(...))
      param_string <- lgb.params2str(all_params)

      lgb.call(
        "LGBM_BoosterResetParameter_R"
        , ret = NULL
        , private$handle
        , param_string
      )

      invisible(self)

    },
207

208
    # Perform boosting update iteration
Guolin Ke's avatar
Guolin Ke committed
209
    update = function(train_set = NULL, fobj = NULL) {

      # Run one boosting iteration.
      #   train_set: optional lgb.Dataset that replaces the current training data
      #   fobj:      optional custom objective; called as fobj(preds, train_set)
      #              and expected to return a list with `grad` and `hess`
      # Returns whatever the native update call returns.

      # No dataset given: re-sync with the stored one if its version has moved on
      if (is.null(train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      # Swap the training data on the native side when there is a dataset to apply
      if (!is.null(train_set)) {

        if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # The predictor is read from the dataset's private fields, exactly as
        # initialize() and add_valid() do; `train_set$predictor` would not
        # look at the private field
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        lgb.call(
          "LGBM_BoosterResetTrainingData_R"
          , ret = NULL
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # Once the objective was switched to "none" a custom objective is required
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        # Boost using the objective configured on the booster
        ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # The first custom-objective update switches the built-in objective off
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        # Gradients and hessians come from the user function, evaluated on the
        # current predictions for the training data (dataset index 1)
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective should "
            , "return a list with attributes (hess, grad)"
          )
        }

        ret <- lgb.call(
          "LGBM_BoosterUpdateOneIterCustom_R"
          , ret = NULL
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # The model changed, so every cached prediction is now stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(ret)

    },
291

292
    # Return one iteration behind
Guolin Ke's avatar
Guolin Ke committed
293
    rollback_one_iter = function() {

      # Undo the most recent boosting iteration on the native booster
      lgb.call("LGBM_BoosterRollbackOneIter_R", ret = NULL, private$handle)

      # Cached predictions no longer match the model; flag every dataset as stale
      for (dataset_pos in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[dataset_pos]] <- FALSE
      }

      invisible(self)

    },
311

312
    # Get current iteration
Guolin Ke's avatar
Guolin Ke committed
313
    current_iter = function() {

      # Integer placeholder handed to lgb.call() as the result slot; the value
      # returned by lgb.call() is the iteration count reported by the booster
      iteration_slot <- 0L
      return(
        lgb.call(
          "LGBM_BoosterGetCurrentIteration_R"
          , ret = iteration_slot
          , private$handle
        )
      )

    },
323

324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
    # Get upper bound
    upper_bound_ = function() {

      # Ask the native booster for its upper bound value.
      # The native routine reports a floating-point number, so the result
      # placeholder must be a double (0.0), not an integer (0L).
      upper_bound <- 0.0
      lgb.call(
        "LGBM_BoosterGetUpperBoundValue_R"
        , ret = upper_bound
        , private$handle
      )

    },

    # Get lower bound
    lower_bound_ = function() {

      # Ask the native booster for its lower bound value.
      # The placeholder passed as `ret` is this function's own local; it is a
      # double (0.0) because the native routine reports a floating-point number.
      lower_bound <- 0.0
      lgb.call(
        "LGBM_BoosterGetLowerBoundValue_R"
        , ret = lower_bound
        , private$handle
      )

    },

348
    # Evaluate data on metrics
Guolin Ke's avatar
Guolin Ke committed
349
    eval = function(data, name, feval = NULL) {

      # Evaluate the booster on `data`.
      #   data:  an lgb.Dataset; registered as a new validation set under `name`
      #          when it is neither the training set nor a known validation set
      #   feval: optional custom evaluation function, forwarded to inner_eval()

      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Dataset index: 1 is the training set, validation set k has index k + 1,
      # and 0 means "not registered yet"
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        # Position of the first registered validation set identical to `data`
        match_pos <- Position(
          f = function(valid_set) identical(data, valid_set)
          , x = private$valid_sets
        )
        if (!is.na(match_pos)) {
          data_idx <- match_pos + 1L
        }
      }

      # Unknown data: attach it now; it becomes the last registered dataset
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      private$inner_eval(name, data_idx, feval)

    },
396

397
    # Evaluation training data
Guolin Ke's avatar
Guolin Ke committed
398
    eval_train = function(feval = NULL) {
      # The training data always occupies dataset index 1
      private$inner_eval(
        data_name = private$name_train_set
        , data_idx = 1L
        , feval = feval
      )
    },
401

402
    # Evaluation validation data
Guolin Ke's avatar
Guolin Ke committed
403
    eval_valid = function(feval = NULL) {

      # Evaluate every registered validation set and return the individual
      # results concatenated into one flat list (empty when there are none)
      results <- list()

      for (valid_pos in seq_along(private$valid_sets)) {
        # Validation set k lives at dataset index k + 1 (index 1 is training)
        results <- c(
          results
          , private$inner_eval(private$name_valid_sets[[valid_pos]], valid_pos + 1L, feval)
        )
      }

      return(results)

    },
425

426
    # Save model
Guolin Ke's avatar
Guolin Ke committed
427
    save_model = function(filename, num_iteration = NULL) {

      # Write the model to `filename`; when `num_iteration` is not given the
      # booster's recorded best iteration is used. Returns the booster invisibly.
      iteration <- if (is.null(num_iteration)) self$best_iter else num_iteration

      lgb.call(
        "LGBM_BoosterSaveModel_R"
        , ret = NULL
        , private$handle
        , as.integer(iteration)
        , lgb.c_str(filename)
      )

      invisible(self)
    },
446

447
448
    # Save model to string
    save_model_to_string = function(num_iteration = NULL) {

      # Serialise the model to a string; when `num_iteration` is not given the
      # booster's recorded best iteration is used
      iteration <- if (is.null(num_iteration)) self$best_iter else num_iteration

      model_string <- lgb.call.return.str(
        "LGBM_BoosterSaveModelToString_R"
        , private$handle
        , as.integer(iteration)
      )
      return(model_string)

    },
463

464
    # Dump model in memory
Guolin Ke's avatar
Guolin Ke committed
465
    dump_model = function(num_iteration = NULL) {

      # Return the string produced by the native dump routine; when
      # `num_iteration` is not given the recorded best iteration is used
      iteration <- if (is.null(num_iteration)) self$best_iter else num_iteration

      return(
        lgb.call.return.str(
          "LGBM_BoosterDumpModel_R"
          , private$handle
          , as.integer(iteration)
        )
      )

    },
480

481
    # Predict on new data
Guolin Ke's avatar
Guolin Ke committed
482
    predict = function(data,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE, ...) {

      # Predict on `data` through a Predictor built on this booster's handle.
      # Arguments in `...` go to the Predictor constructor; all other
      # arguments are forwarded to its predict() method unchanged.

      # Fall back to the recorded best iteration when none is requested
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      booster_predictor <- Predictor$new(private$handle, ...)
      return(
        booster_predictor$predict(
          data
          , num_iteration
          , rawscore
          , predleaf
          , predcontrib
          , header
          , reshape
        )
      )

    },
500

501
502
503
    # Transform into predictor
    to_predictor = function() {
      # Wrap this booster's native handle in a new Predictor object
      booster_predictor <- Predictor$new(private$handle)
      return(booster_predictor)
    },
505

506
    # Used for save
507
    raw = NA,
508

509
    # Save model to temporary file for in-memory saving
510
    save = function() {

      # Keep a string copy of the model in the public `raw` field; passing NULL
      # makes save_model_to_string() fall back to the best iteration
      model_string <- self$save_model_to_string(NULL)
      self$raw <- model_string

    }
516

Guolin Ke's avatar
Guolin Ke committed
517
518
  ),
  private = list(
519
520
521
522
523
524
525
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
526
527
    num_class = 1L,
    num_dataset = 0L,
528
529
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
530
    higher_better_inner_eval = NULL,
531
    set_objective_to_none = FALSE,
532
    train_set_version = 0L,
533
534
    # Predict data
    inner_predict = function(idx) {

      # Return the booster's current predictions for dataset `idx`
      # (1 = training data, k + 1 = validation set k), using a per-dataset
      # buffer that is refreshed at most once per boosting iteration.

      # Validate the index first: looking up the name of an unregistered
      # validation set would otherwise fail with "subscript out of bounds"
      # before this message could ever be shown
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Name under which this dataset's buffer is stored
      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # First request for this dataset: size the buffer from the native side
      # (the native call is given idx - 1)
      if (is.null(private$predict_buffer[[data_name]])) {

        npred <- 0L
        npred <- lgb.call(
          "LGBM_BoosterGetNumPredict_R"
          , ret = npred
          , private$handle
          , as.integer(idx - 1L)
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer only when it is stale for the current iteration
      if (!private$is_predicted_cur_iter[[idx]]) {

        private$predict_buffer[[data_name]] <- lgb.call(
          "LGBM_BoosterGetPredict_R"
          , ret = private$predict_buffer[[data_name]]
          , private$handle
          , as.integer(idx - 1L)
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },
580

581
    # Get evaluation information
Guolin Ke's avatar
Guolin Ke committed
582
    get_eval_info = function() {

      # Return the booster's metric names, fetching and caching them on first use

      # Already cached: nothing to fetch
      if (!is.null(private$eval_names)) {
        return(private$eval_names)
      }

      # The native side returns all metric names in one tab-separated string
      raw_names <- lgb.call.return.str(
        "LGBM_BoosterGetEvalNames_R"
        , private$handle
      )

      if (nchar(raw_names) > 0L) {

        metric_names <- strsplit(raw_names, "\t")[[1L]]
        private$eval_names <- metric_names

        # Metrics whose name starts with ndcg, map or auc are flagged as
        # "higher is better"
        private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc", metric_names)

      }

      return(private$eval_names)

    },
609

610
    # Perform inner evaluation
Guolin Ke's avatar
Guolin Ke committed
611
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Evaluate dataset `data_idx` (1 = training data, k + 1 = validation
      # set k) and return a list of records, one per built-in metric plus one
      # for `feval` when supplied. Each record carries data_name, name, value
      # and higher_better.

      # Reject indices beyond the datasets registered so far
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Make sure metric names and their direction flags are cached
      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Placeholder that receives one value per metric (the native call is
        # given data_idx - 1)
        metric_values <- numeric(length(private$eval_names))
        metric_values <- lgb.call(
          "LGBM_BoosterGetEval_R"
          , ret = metric_values
          , private$handle
          , as.integer(data_idx - 1L)
        )

        # Build all records in one pass instead of growing the list in a loop
        ret <- lapply(
          X = seq_along(private$eval_names)
          , FUN = function(i) {
            res <- list()
            res$data_name <- data_name
            res$name <- private$eval_names[i]
            res$value <- metric_values[i]
            res$higher_better <- private$higher_better_inner_eval[i]
            res
          }
        )

      }

      # Optional user-supplied metric
      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset object that matches data_idx
        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        # feval receives the current predictions and the dataset
        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a "
            , "list with attribute (name, value, higher_better)"
          )
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }
685

Guolin Ke's avatar
Guolin Ke committed
686
687
688
  )
)

689
690
691
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
Guolin Ke's avatar
Guolin Ke committed
692
693
694
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
695
#' @param rawscore whether the prediction should be returned in the form of original untransformed
696
697
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
698
#' @param predleaf whether predict leaf index instead.
699
#' @param predcontrib return per-feature contributions for each record.
Guolin Ke's avatar
Guolin Ke committed
700
#' @param header only used for prediction for text file. True if text file has header
701
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
702
#'                prediction outputs per case.
James Lamb's avatar
James Lamb committed
703
704
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
705
706
707
708
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
709
#'
710
711
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
712
#'
Guolin Ke's avatar
Guolin Ke committed
713
#' @examples
714
715
716
717
718
719
720
721
722
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
723
724
725
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
726
#'   , nrounds = 10L
727
#'   , valids = valids
728
729
730
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
731
#' )
732
#' preds <- predict(model, test$data)
Guolin Ke's avatar
Guolin Ke committed
733
#' @export
James Lamb's avatar
James Lamb committed
734
735
736
737
738
739
740
predict.lgb.Booster <- function(object,
                                data,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  # Only lgb.Booster objects are accepted
  if (!lgb.is.Booster(object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the booster's own predict() method, forwarding every argument
  return(
    object$predict(
      data = data
      , num_iteration = num_iteration
      , rawscore = rawscore
      , predleaf = predleaf
      , predcontrib = predcontrib
      , header = header
      , reshape = reshape
      , ...
    )
  )

}

762
763
764
765
#' @name lgb.load
#' @title Load LightGBM model
#' @description  Load LightGBM takes in either a file path or model string.
#'               If both are provided, Load will default to loading from file
Guolin Ke's avatar
Guolin Ke committed
766
#' @param filename path of model file
767
#' @param model_str a str containing the model
768
#'
769
#' @return lgb.Booster
770
#'
Guolin Ke's avatar
Guolin Ke committed
771
#' @examples
772
773
774
775
776
777
778
779
780
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
781
782
783
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
784
#'   , nrounds = 10L
785
#'   , valids = valids
786
787
788
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
789
#' )
790
#' lgb.save(model, "model.txt")
791
792
793
#' load_booster <- lgb.load(filename = "model.txt")
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
794
#'
Guolin Ke's avatar
Guolin Ke committed
795
#' @export
796
lgb.load <- function(filename = NULL, model_str = NULL) {

  # At least one source is required
  if (is.null(filename) && is.null(model_str)) {
    stop("lgb.load: either filename or model_str must be given")
  }

  # A file path takes precedence over a model string
  if (!is.null(filename)) {

    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }

    # file.exists() is vectorised; insist on one path so the check below
    # always yields a single TRUE/FALSE
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single string")
    }

    if (!file.exists(filename)) {
      stop("lgb.load: file does not exist for supplied filename")
    }

    return(invisible(Booster$new(modelfile = filename)))

  }

  # Only model_str is left at this point
  if (!is.character(model_str)) {
    stop("lgb.load: model_str should be character")
  }

  return(invisible(Booster$new(model_str = model_str)))

}

826
827
828
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
Guolin Ke's avatar
Guolin Ke committed
829
830
831
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
832
#'
833
#' @return lgb.Booster
834
#'
Guolin Ke's avatar
Guolin Ke committed
835
#' @examples
836
837
838
839
840
841
842
843
844
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
845
846
847
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
848
#'   , nrounds = 10L
849
#'   , valids = valids
850
851
852
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
853
#' )
854
#' lgb.save(model, "model.txt")
Guolin Ke's avatar
Guolin Ke committed
855
#' @export
856
lgb.save <- function(booster, filename, num_iteration = NULL) {

  # Only lgb.Booster objects can be saved
  if (!lgb.is.Booster(booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # The target must be exactly one character string
  if (!is.character(filename) || length(filename) != 1L) {
    stop("lgb.save: filename should be a string")
  }

  # Delegate to the booster; its result is returned invisibly
  saved <- booster$save_model(filename, num_iteration)
  invisible(saved)

}

873
874
875
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration iteration count handed to the booster's \code{dump_model} method;
#'                      NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#'
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # Reject anything that is not an lgb.Booster before calling its methods.
  # The message names lgb.dump so the failing call can be identified.
  if (!lgb.is.Booster(booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the booster and return its dump visibly
  booster$dump_model(num_iteration)

}

915
916
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for.
#'                  Must be a single string.
#' @param eval_name Name of the evaluation metric to return results for.
#'                  Must be a single string.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' # train a regression model
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Check if booster is booster
  if (!lgb.is.Booster(booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Check if data and evaluation name are characters or not
  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # The %in% checks below feed if(), which needs a length-one condition;
  # fail here with a readable message instead of a cryptic one later
  if (length(data_name) != 1L || length(eval_name) != 1L) {
    stop("lgb.get.eval.result: data_name and eval_name should each be a single string")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check if evaluation result is existing
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick the recorded values or their errors. `[[` is used instead of `$`
  # because `$` partial-matches on lists and "eval" is a prefix of "eval_err".
  metric_record <- booster$record_evals[[data_name]][[eval_name]]
  result <- if (is_err) metric_record[["eval_err"]] else metric_record[["eval"]]

  # No iterations requested: return everything that was recorded
  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Records are stored starting at position 1, but iteration numbers start at
  # record_evals$start_iter, so shift the requested iterations accordingly
  iters <- as.integer(iters)
  delta <- booster$record_evals[["start_iter"]] - 1.0
  iters <- iters - delta

  # Return requested result
  as.numeric(result[iters])
}