lgb.Booster.R 29.7 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
3
  classname = "lgb.Booster",
4
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
5
  public = list(
6

7
    best_iter = -1L,
8
    best_score = NA_real_,
9
    params = list(),
Guolin Ke's avatar
Guolin Ke committed
10
    record_evals = list(),
11

12
13
    # Destructor: releases the booster handle through LGBM_BoosterFree_R if one
    # is still held, then clears the stored reference. Returns NULL invisibly.
    finalize = function() {

      # Nothing to release when the handle is already null
      if (lgb.is.null.handle(x = private$handle)) {
        return(invisible(NULL))
      }

      # Free the booster first, then forget the handle
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL

      return(invisible(NULL))

    },
30

31
32
    # Constructor: creates the booster handle from exactly one source, tried in
    # this order: a training lgb.Dataset, a model file path, a model string.
    #   params:    list of parameters; named arguments in `...` are appended
    #   train_set: lgb.Dataset to train on
    #   modelfile: character path of a saved model
    #   model_str: character string holding a saved model
    # Stores the merged parameters in self$params and returns NULL invisibly.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      params <- append(x = params, values = list(...))
      booster_handle <- NULL

      # NOTE: try() reports an error raised inside this block but does not
      # re-raise it; such a failure leaves booster_handle as NULL and is then
      # surfaced by the null-handle check after the block
      try({

        if (!is.null(train_set)) {

          # Source 1: a training dataset
          if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }

          dataset_private <- train_set$.__enclos_env__$private
          dataset_handle <- dataset_private$get_handle()

          # Parameters stored on the dataset override the ones passed in
          params <- modifyList(params, train_set$get_params())
          params_str <- lgb.params2str(params = params)

          booster_handle <- .Call(
            LGBM_BoosterCreate_R
            , dataset_handle
            , params_str
          )

          # Remember the dataset and the version it had when attached
          private$train_set <- train_set
          private$train_set_version <- dataset_private$version
          private$num_dataset <- 1L
          private$init_predictor <- dataset_private$predictor

          # Merge the dataset's initial predictor, when it has one, into the
          # newly created booster
          if (!is.null(private$init_predictor)) {
            .Call(
              LGBM_BoosterMerge_R
              , booster_handle
              , private$init_predictor$.__enclos_env__$private$handle
            )
          }

          # One "predicted this iteration" flag per attached dataset
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          # Source 2: a model file on disk
          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          booster_handle <- .Call(
            LGBM_BoosterCreateFromModelfile_R
            , modelfile
          )

        } else if (!is.null(model_str)) {

          # Source 3: a model held in a string
          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          booster_handle <- .Call(
            LGBM_BoosterLoadModelFromString_R
            , model_str
          )

        } else {

          # No source at all
          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      # Any failure above ends up here with a null handle
      if (isTRUE(lgb.is.null.handle(x = booster_handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      # Tag the handle and keep it on the object
      class(booster_handle) <- "lgb.Booster.handle"
      private$handle <- booster_handle

      # NOTE(review): num_class is passed to the call below as its last
      # argument; presumably the native side fills it in - confirm
      private$num_class <- 1L
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
      )

      self$params <- params

      return(invisible(NULL))

    },
144

145
    # Record the label used for the training data in evaluation results.
    # Returns the booster invisibly.
    set_train_data_name = function(name) {

      private$name_train_set <- name

      invisible(self)

    },
153

154
    # Attach a validation dataset to the booster.
    #   data: lgb.Dataset whose predictor is identical to the booster's
    #         initial predictor
    #   name: character label stored for this dataset
    # Returns the booster invisibly.
    add_valid = function(data, name) {

      # Only lgb.Dataset objects can be attached
      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # The dataset has to share the booster's initial predictor
      data_private <- data$.__enclos_env__$private
      same_predictor <- identical(data_private$predictor, private$init_predictor)
      if (!same_predictor) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      # The label has to be character
      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Register the dataset with the booster
      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data_private$get_handle()
      )

      # Book-keeping: dataset, its label, the dataset count, and a fresh
      # "predicted this iteration" flag
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },
191

192
    # Update the booster's parameters.
    #   params: list of parameters; layered on top of self$params when that is
    #           a list
    #   ...:    further named parameters, which override `params`
    # The merged list is sent to the booster and stored back in self$params.
    # Returns the booster invisibly.
    reset_parameter = function(params, ...) {

      # Start from the parameters already stored on the booster, if any
      if (methods::is(self$params, "list")) {
        params <- modifyList(self$params, params)
      }

      # Named arguments in `...` win over everything else
      params <- modifyList(params, list(...))

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , lgb.params2str(params = params)
      )
      self$params <- params

      return(invisible(self))

    },
212

213
    # Perform one boosting iteration.
    #   train_set: optional lgb.Dataset to switch training to before updating;
    #              when NULL, the stored training set is re-attached only if
    #              its version differs from the version recorded on the booster
    #   fobj:      optional custom objective, called as
    #              fobj(predictions, train_set); it must return a list with
    #              elements `grad` and `hess`
    # Returns the booster invisibly.
    update = function(train_set = NULL, fobj = NULL) {

      # No dataset supplied: pick up the stored one if its version has moved
      if (is.null(train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      # (Re-)attach training data when there is a dataset to attach
      if (!is.null(train_set)) {

        if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Read the predictor from the dataset's private fields, the same way
        # initialize() and add_valid() do.
        # NOTE(review): the previous `train_set$predictor` presumably does not
        # name a public field of lgb.Dataset - confirm against that class
        dataset_private <- train_set$.__enclos_env__$private
        if (!identical(dataset_private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , dataset_private$get_handle()
        )

        # Remember the dataset and the version it had when attached
        private$train_set <- train_set
        private$train_set_version <- dataset_private$version

      }

      if (is.null(fobj)) {

        # The built-in objective is gone once a custom objective has been used
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # First use of a custom objective: switch the booster's objective off
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        # Ask the custom objective for gradients and hessians on the training
        # predictions (dataset index 1)
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          # Single-line message: the earlier literal spanned two source lines
          # and so carried a newline and indentation into the error text
          stop("lgb.Booster.update: custom objective should return a list with attributes (hess, grad)")
        }

        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # Mark every dataset's cached predictions as stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },
297

298
    # Undo the most recent boosting iteration and mark all cached predictions
    # as stale. Returns the booster invisibly.
    rollback_one_iter = function() {

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # Reset every "predicted this iteration" flag
      private$is_predicted_cur_iter <- lapply(
        X = private$is_predicted_cur_iter
        , FUN = function(flag) FALSE
      )

      return(invisible(self))

    },
315

316
    # Return the booster's current iteration.
    # NOTE(review): the value is passed to the call as its last argument and
    # then returned; presumably the native side writes the result into it -
    # confirm
    current_iter = function() {

      iter_out <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , iter_out
      )

      return(iter_out)

    },
328

329
    # Return the value reported by LGBM_BoosterGetUpperBoundValue_R.
    # NOTE(review): the value is passed to the call as its last argument and
    # then returned; presumably the native side writes the result into it -
    # confirm
    upper_bound = function() {

      bound_out <- 0.0
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , bound_out
      )

      return(bound_out)

    },

    # Return the value reported by LGBM_BoosterGetLowerBoundValue_R.
    # NOTE(review): the value is passed to the call as its last argument and
    # then returned; presumably the native side writes the result into it -
    # confirm
    lower_bound = function() {

      bound_out <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , bound_out
      )

      return(bound_out)

    },

355
    # Evaluate the booster on a dataset.
    #   data:  lgb.Dataset; when it is neither the training set nor an already
    #          attached validation set, it is attached first via add_valid()
    #   name:  label reported with the results (and used when attaching)
    #   feval: optional custom evaluation function
    # Returns the list produced by inner_eval().
    eval = function(data, name, feval = NULL) {

      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the dataset: 1 is the training set, k + 1 is the k-th
      # validation set, 0 means "not attached yet"
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (pos in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[pos]])) {
            data_idx <- pos + 1L
            break
          }
        }
      }

      # Unknown dataset: attach it; it then sits at the last index
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },
409

410
    # Evaluate the booster on the training data (dataset index 1), optionally
    # with a custom evaluation function. Returns the list from inner_eval().
    eval_train = function(feval = NULL) {
      return(
        private$inner_eval(
          data_name = private$name_train_set
          , data_idx = 1L
          , feval = feval
        )
      )
    },
414

415
    # Evaluate the booster on every attached validation set, optionally with a
    # custom evaluation function. Returns one flat list holding the results of
    # all sets in attachment order; an empty list when no set is attached.
    eval_valid = function(feval = NULL) {

      if (length(private$valid_sets) <= 0L) {
        return(list())
      }

      # The k-th validation set lives at dataset index k + 1
      per_set <- lapply(
        X = seq_along(private$valid_sets)
        , FUN = function(k) {
          private$inner_eval(private$name_valid_sets[[k]], k + 1L, feval)
        }
      )

      # Concatenate the per-set result lists into a single list
      return(do.call(what = c, args = per_set))

    },
437

438
    # Save the model to a file.
    #   filename:                path of the file to write; a leading "~" is
    #                            expanded before the path is handed to the
    #                            native routine
    #   num_iteration:           iteration count passed to the native routine;
    #                            NULL means self$best_iter
    #   feature_importance_type: coerced to integer and passed through
    # Returns the booster invisibly.
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      # Fall back to the best iteration when none is requested
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # R's own file functions tilde-expand paths, but the path here goes
      # straight into compiled code, so expand it explicitly
      filename <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },
457

458
    # Serialise the model and return the result of
    # LGBM_BoosterSaveModelToString_R.
    #   num_iteration:           iteration count passed through; NULL means
    #                            self$best_iter
    #   feature_importance_type: coerced to integer and passed through
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      # Fall back to the best iteration when none is requested
      iter_count <- if (is.null(num_iteration)) self$best_iter else num_iteration

      serialised <- .Call(
          LGBM_BoosterSaveModelToString_R
          , private$handle
          , as.integer(iter_count)
          , as.integer(feature_importance_type)
      )

      return(serialised)

    },
476

477
    # Dump the model and return the result of LGBM_BoosterDumpModel_R.
    #   num_iteration:           iteration count passed through; NULL means
    #                            self$best_iter
    #   feature_importance_type: coerced to integer and passed through
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      # Fall back to the best iteration when none is requested
      iter_count <- if (is.null(num_iteration)) self$best_iter else num_iteration

      dumped <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(iter_count)
        , as.integer(feature_importance_type)
      )

      return(dumped)

    },
495

496
    # Predict on new data by building a Predictor on the booster's handle and
    # delegating to its predict() method.
    #   start_iteration: NULL is replaced by 0L
    #   num_iteration:   NULL is replaced by self$best_iter
    #   ...:             collected into a list and given to the Predictor as
    #                    its `params`
    # All remaining arguments are forwarded unchanged.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE,
                       ...) {

      # Fill in the defaults for the iteration window
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      predictor <- Predictor$new(
        modelfile = private$handle
        , params = list(...)
      )

      preds <- predictor$predict(
        data = data
        , start_iteration = start_iteration
        , num_iteration = num_iteration
        , rawscore = rawscore
        , predleaf = predleaf
        , predcontrib = predcontrib
        , header = header
        , reshape = reshape
      )

      return(preds)

    },
536

537
538
    # Build a Predictor on top of this booster's handle and return it
    to_predictor = function() {
      predictor <- Predictor$new(modelfile = private$handle)
      return(predictor)
    },
541

542
    # Used for save
543
    raw = NA,
544

545
    # Store the serialised model (as produced by save_model_to_string() with
    # num_iteration = NULL) in the `raw` field. Returns NULL invisibly.
    save = function() {

      self$raw <- self$save_model_to_string(NULL)

      invisible(NULL)

    }
554

Guolin Ke's avatar
Guolin Ke committed
555
556
  ),
  private = list(
557
558
559
560
561
562
563
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
564
565
    num_class = 1L,
    num_dataset = 0L,
566
567
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
568
    higher_better_inner_eval = NULL,
569
    set_objective_to_none = FALSE,
570
    train_set_version = 0L,
571
572
    # Return the booster's predictions for one attached dataset.
    #   idx: dataset index; 1 is the training set, k + 1 the k-th validation
    #        set. An index above num_dataset is an error.
    # Predictions are kept in a per-dataset buffer and are only refreshed when
    # the dataset's "predicted this iteration" flag is FALSE.
    inner_predict = function(idx) {

      # Validate the index first, so an out-of-range value produces this
      # message rather than a subscript error from the name lookup below
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # The buffer is keyed by the dataset's label
      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # First request for this dataset: allocate its buffer
      if (is.null(private$predict_buffer[[data_name]])) {

        # NOTE(review): npred is passed to the call as its last argument and
        # then used as the buffer length; presumably the native side writes
        # the prediction count into it - confirm
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer unless it already holds this iteration's values
      if (!private$is_predicted_cur_iter[[idx]]) {

        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },
617

618
    # Return the names of the booster's evaluation metrics. On the first call
    # that yields a non-empty result, the names are cached together with a
    # per-metric "higher is better" lookup.
    get_eval_info = function() {

      # Already cached: nothing to fetch
      if (!is.null(private$eval_names)) {
        return(private$eval_names)
      }

      fetched_names <- .Call(
        LGBM_BoosterGetEvalNames_R
        , private$handle
      )

      if (length(fetched_names) > 0L) {

        private$eval_names <- fetched_names

        # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
        # ndcg metric evaluated at the first "query result" in learning-to-rank
        base_names <- gsub("@.*", "", fetched_names)
        private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[base_names]

      }

      return(private$eval_names)

    },
646

647
    # Evaluate the booster on one attached dataset.
    #   data_name: label stored in every result entry
    #   data_idx:  dataset index; 1 is the training set, k + 1 the k-th
    #              validation set. An index above num_dataset is an error.
    #   feval:     optional custom evaluation function, called as
    #              feval(predictions, dataset); it must return a list with
    #              elements `name`, `value` and `higher_better`
    # Returns a list with one entry per built-in metric (elements data_name,
    # name, value, higher_better), followed by the custom result if any.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Make sure the metric names are loaded
      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # One slot per metric.
        # NOTE(review): tmp_vals is passed to the call as its last argument
        # and then read; presumably the native side fills it in - confirm
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        # Build all entries in one pass instead of growing the list
        ret <- lapply(
          X = seq_along(private$eval_names)
          , FUN = function(i) {
            res <- list()
            res$data_name <- data_name
            res$name <- private$eval_names[i]
            res$value <- tmp_vals[i]
            res$higher_better <- private$higher_better_inner_eval[i]
            res
          }
        )

      }

      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset object that belongs to data_idx
        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          # Single-line message: the earlier literal spanned two source lines
          # and so carried a newline and indentation into the error text
          stop("lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)")
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }
721

Guolin Ke's avatar
Guolin Ke committed
722
723
724
  )
)

725
726
727
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
Guolin Ke's avatar
Guolin Ke committed
728
729
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
730
731
732
733
734
735
736
737
#' @param start_iteration int or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration int or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
738
#' @param rawscore whether the prediction should be returned in the form of the original untransformed
739
740
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
741
#' @param predleaf whether predict leaf index instead.
742
#' @param predcontrib return per-feature contributions for each record.
Guolin Ke's avatar
Guolin Ke committed
743
#' @param header only used for prediction for text file. True if text file has header
744
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
745
#'                prediction outputs per case.
James Lamb's avatar
James Lamb committed
746
747
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
748
749
750
751
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
752
#'
753
754
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
755
#'
Guolin Ke's avatar
Guolin Ke committed
756
#' @examples
757
#' \donttest{
758
759
760
761
762
763
764
765
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
766
767
768
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
769
#'   , nrounds = 5L
770
#'   , valids = valids
771
772
#'   , min_data = 1L
#'   , learning_rate = 1.0
773
#' )
774
#' preds <- predict(model, test$data)
775
#' }
Guolin Ke's avatar
Guolin Ke committed
776
#' @export
James Lamb's avatar
James Lamb committed
777
778
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  # Only lgb.Booster objects are accepted
  is_booster <- lgb.is.Booster(x = object)
  if (!is_booster) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the booster's own predict() method, forwarding every argument
  preds <- object$predict(
    data = data
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , reshape = reshape
    , ...
  )

  return(preds)
}

808
809
#' @name lgb.load
#' @title Load LightGBM model
810
811
#' @description Load LightGBM takes in either a file path or model string.
#'              If both are provided, Load will default to loading from file
Guolin Ke's avatar
Guolin Ke committed
812
#' @param filename path of model file
813
#' @param model_str a str containing the model
814
#'
815
#' @return lgb.Booster
816
#'
Guolin Ke's avatar
Guolin Ke committed
817
#' @examples
818
#' \donttest{
819
820
821
822
823
824
825
826
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
827
828
829
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
830
#'   , nrounds = 5L
831
#'   , valids = valids
832
833
#'   , min_data = 1L
#'   , learning_rate = 1.0
834
#'   , early_stopping_rounds = 3L
835
#' )
836
837
838
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
839
840
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
841
#' }
Guolin Ke's avatar
Guolin Ke committed
842
#' @export
843
lgb.load <- function(filename = NULL, model_str = NULL) {

  # A file path, when given, takes priority over a model string
  if (!is.null(filename)) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }

    # file.exists() tilde-expands its argument, but the path is afterwards
    # handed to the Booster constructor as-is; expand it once up front so the
    # existence check and the load look at the same path
    filename <- path.expand(filename)
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  # Otherwise load from the model string
  if (!is.null(model_str)) {
    if (!is.character(model_str)) {
      stop("lgb.load: model_str should be character")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  # Neither source was supplied
  stop("lgb.load: either filename or model_str must be given")
}

868
869
870
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
Guolin Ke's avatar
Guolin Ke committed
871
872
873
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save; NULL or <= 0 means use best iteration
874
#'
875
#' @return lgb.Booster
876
#'
Guolin Ke's avatar
Guolin Ke committed
877
#' @examples
878
#' \donttest{
879
880
881
882
883
884
885
886
887
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
888
889
890
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
891
#'   , nrounds = 10L
892
#'   , valids = valids
893
894
895
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
896
#' )
897
#' lgb.save(model, tempfile(fileext = ".txt"))
898
#' }
Guolin Ke's avatar
Guolin Ke committed
899
#' @export
900
lgb.save <- function(booster, filename, num_iteration = NULL) {
901

902
  if (!lgb.is.Booster(x = booster)) {
903
904
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }
905

906
907
  if (!(is.character(filename) && length(filename) == 1L)) {
    stop("lgb.save: filename should be a string")
908
  }
909

910
  # Store booster
911
912
913
914
915
916
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )
917

Guolin Ke's avatar
Guolin Ke committed
918
919
}

920
921
922
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration Number of iterations to use, forwarded to the booster's \code{dump_model}
#'                      method. NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # Only lgb.Booster objects can be dumped; the message names this function
  # so the failing call can be identified
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))

}

962
963
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Check if booster is booster
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Check if data and evaluation name are characters or not
  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check that the requested metric was recorded for this dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  result <- booster$record_evals[[data_name]][[eval_name]][[.EVAL_KEY()]]

  # Check if error is requested
  if (is_err) {
    result <- booster$record_evals[[data_name]][[eval_name]][[.EVAL_ERR_KEY()]]
  }

  # No specific iterations requested: return every recorded value
  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Shift the requested iterations by (start_iter - 1) so that they index
  # into the recorded results
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  # Return requested result
  return(as.numeric(result[iters]))
}