# lgb.Booster.R
#' @importFrom R6 R6Class
#' @importFrom utils modifyList
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Default 'num_iteration' used by save_model(), save_model_to_string(),
    # dump_model() and predict() when they are called with num_iteration = NULL
    best_iter = -1L,
    # NOTE(review): not written by this class; presumably filled in by the
    # training loop -- confirm against the callers
    best_score = NA_real_,
    # Parameters the booster was created with / last reset to
    params = list(),
    # NOTE(review): not written by this class; presumably filled in by the
    # training loop -- confirm against the callers
    record_evals = list(),

    # Finalize will free up the native handle
    finalize = function() {
      .Call(
        LGBM_BoosterFree_R
        , private$handle
      )
      private$handle <- NULL
      return(invisible(NULL))
    },

    # Initialize will create a starter booster from exactly one source, checked
    # in this order: an lgb.Dataset ('train_set'), a model file path
    # ('modelfile'), or a character/raw model string ('model_str').
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL) {

      # Create parameters and handle
      handle <- NULL

      # Attempts to create a handle. Errors raised inside try() are printed but
      # not re-raised; they surface as the "cannot create Booster handle" error
      # below because 'handle' stays NULL.
      try({

        # Check if training dataset is not null
        if (!is.null(train_set)) {
          # Check if training dataset is lgb.Dataset or not
          if (!lgb.is.Dataset(train_set)) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }
          train_set_handle <- train_set$.__enclos_env__$private$get_handle()
          # Dataset parameters override entries of the same name in 'params'
          params <- utils::modifyList(params, train_set$get_params())
          params_str <- lgb.params2str(params = params)
          # Store booster handle
          handle <- .Call(
            LGBM_BoosterCreate_R
            , train_set_handle
            , params_str
          )

          # Create private booster information
          private$train_set <- train_set
          private$train_set_version <- train_set$.__enclos_env__$private$version
          private$num_dataset <- 1L
          private$init_predictor <- train_set$.__enclos_env__$private$predictor

          # Check if predictor is existing
          if (!is.null(private$init_predictor)) {

            # Merge booster
            .Call(
              LGBM_BoosterMerge_R
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )

          }

          # Check current iteration
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          # Do we have a model file as character?
          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          modelfile <- path.expand(modelfile)

          # Create booster from model
          handle <- .Call(
            LGBM_BoosterCreateFromModelfile_R
            , modelfile
          )

        } else if (!is.null(model_str)) {

          # Do we have a model_str as character/raw?
          if (!is.raw(model_str) && !is.character(model_str)) {
            stop("lgb.Booster: Can only use a character/raw vector as model_str")
          }

          # Create booster from model
          handle <- .Call(
            LGBM_BoosterLoadModelFromString_R
            , model_str
          )

        } else {

          # Booster non existent
          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      # Check whether the handle was created properly if it was not stopped earlier by a stop call
      if (isTRUE(lgb.is.null.handle(x = handle))) {

        stop("lgb.Booster: cannot create Booster handle")

      } else {

        # Create class
        class(handle) <- "lgb.Booster.handle"
        private$handle <- handle
        private$num_class <- 1L
        # The native routine is passed private$num_class as an output argument
        # (it is expected to write the class count into it in place)
        .Call(
          LGBM_BoosterGetNumClasses_R
          , private$handle
          , private$num_class
        )

      }

      self$params <- params

      return(invisible(NULL))

    },

    # Set the name reported for the training data in evaluation results
    set_train_data_name = function(name) {

      # Set name
      private$name_train_set <- name
      return(invisible(self))

    },

    # Register an lgb.Dataset as a validation set under 'name'. The dataset must
    # share this booster's init predictor (compared by identity).
    add_valid = function(data, name) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Add validation data to booster
      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Merge 'params' (and, deprecated, anything in '...') over the current
    # params, push the result to the native booster and store it in self$params.
    reset_parameter = function(params, ...) {

      additional_params <- list(...)
      if (length(additional_params) > 0L) {
        warning(paste0(
          "Booster$reset_parameter(): Found the following passed through '...': "
          , paste(names(additional_params), collapse = ", ")
          , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
          , "Add these to 'params' instead."
        ))
      }

      if (methods::is(self$params, "list")) {
        params <- utils::modifyList(self$params, params)
      }

      params <- utils::modifyList(params, additional_params)
      params_str <- lgb.params2str(params = params)

      self$restore_handle()

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))

    },

    # Perform one boosting iteration.
    # 'train_set' (optional) replaces the training data; it must be an
    # lgb.Dataset sharing this booster's init predictor.
    # 'fobj' (optional) is a custom objective called as fobj(preds, train_set);
    # it must return a list with 'grad' and 'hess'. The first time it is used,
    # the booster's objective is reset to "none".
    update = function(train_set = NULL, fobj = NULL) {

      # Re-register the stored training set if it changed since it was registered
      if (is.null(train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.is.Dataset(train_set)) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # A dataset's predictor lives in its private environment; that is where
        # initialize() and add_valid() read it from. 'train_set$predictor' would
        # not reach that field, so compare the private one here as well.
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      # Check if objective is empty
      if (is.null(fobj)) {
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }
        # Boost iteration from known objective
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }
        # Perform objective calculation on the current training predictions
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        # Check for gradient and hessian as list
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should
            return a list with attributes (hess, grad)")
        }

        # Return custom boosting gradient/hessian
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # Cached predictions for every dataset are now out of date
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Undo the most recent boosting iteration
    rollback_one_iter = function() {

      self$restore_handle()

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # Cached predictions for every dataset are now out of date
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Get current iteration (relies on the native routine writing into
    # 'cur_iter' in place)
    current_iter = function() {

      self$restore_handle()

      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Get upper bound (relies on the native routine writing into 'upper_bound'
    # in place)
    upper_bound = function() {

      self$restore_handle()

      upper_bound <- 0.0
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)

    },

    # Get lower bound (relies on the native routine writing into 'lower_bound'
    # in place)
    lower_bound = function() {

      self$restore_handle()

      lower_bound <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)

    },

    # Evaluate 'data' on the booster's metrics (plus 'feval' if given).
    # 'data' is matched by identity against the training set and the validation
    # sets; if it matches none of them it is registered as a new validation set
    # called 'name'.
    eval = function(data, name, feval = NULL) {

      if (!lgb.is.Dataset(data)) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Check for identical data (index 1 = training set, i + 1 = validation set i)
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {

        # Check for validation data
        if (length(private$valid_sets) > 0L) {

          for (i in seq_along(private$valid_sets)) {

            # Check for identical validation data
            if (identical(data, private$valid_sets[[i]])) {

              # Found identical data, stop searching
              data_idx <- i + 1L
              break

            }

          }

        }

      }

      # Unknown dataset: register it as a validation set first
      if (data_idx == 0L) {

        # Add validation data by name
        self$add_valid(data, name)
        data_idx <- private$num_dataset

      }

      # Evaluate data
      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },

    # Evaluate the training data
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluate every registered validation set; results are concatenated in
    # registration order
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },

    # Save model to 'filename' (num_iteration = NULL means self$best_iter)
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      filename <- path.expand(filename)

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , filename
      )

      return(invisible(self))
    },

    # Serialize the model to a string; returned as character when as_char = TRUE,
    # otherwise as the raw vector produced by the native routine
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      if (as_char) {
        model_str <- rawToChar(model_str)
      }

      return(model_str)

    },

    # Dump model in memory (whatever LGBM_BoosterDumpModel_R returns)
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      self$restore_handle()

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      model_str <- .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

      return(model_str)

    },

    # Predict on new data. NULL num_iteration means self$best_iter and NULL
    # start_iteration means 0. Entries in '...' (deprecated) are merged over
    # 'params'.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE,
                       params = list(),
                       ...) {

      self$restore_handle()

      additional_params <- list(...)
      if (length(additional_params) > 0L) {
        warning(paste0(
          "Booster$predict(): Found the following passed through '...': "
          , paste(names(additional_params), collapse = ", ")
          , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
          , "Add these to 'params' instead. See ?predict.lgb.Booster for documentation on how to call this function."
        ))
      }

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # Predict on new data
      params <- utils::modifyList(params, additional_params)
      predictor <- Predictor$new(
        modelfile = private$handle
        , params = params
      )
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
          , reshape = reshape
        )
      )

    },

    # Transform into predictor
    to_predictor = function() {
      return(Predictor$new(modelfile = private$handle))
    },

    # Used for serialization: raw model bytes, filled by save_raw()
    raw = NULL,

    # Store serialized raw bytes in model object (no-op if already stored)
    save_raw = function() {
      if (is.null(self$raw)) {
        self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
      }
      return(invisible(NULL))

    },

    # Discard the stored raw bytes
    drop_raw = function() {
      self$raw <- NULL
      return(invisible(NULL))
    },

    # TRUE when the native handle is missing / null
    check_null_handle = function() {
      return(lgb.is.null.handle(private$handle))
    },

    # Rebuild a null native handle from self$raw. Without stored raw bytes,
    # LGBM_NullBoosterHandleError_R is called (presumably raising an error).
    restore_handle = function() {
      if (self$check_null_handle()) {
        if (is.null(self$raw)) {
          .Call(LGBM_NullBoosterHandleError_R)
        }
        private$handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
      }
      return(invisible(NULL))
    },

    get_handle = function() {
      return(private$handle)
    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    # Per-dataset prediction buffers keyed by dataset name
    predict_buffer = list(),
    # One flag per dataset: TRUE when its buffer holds current-iteration predictions
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,
    train_set_version = 0L,

    # Return (cached) predictions for dataset 'idx': 1 = training set,
    # i + 1 = validation set i
    inner_predict = function(idx) {

      # Check for unknown dataset (over the maximum provided range) before the
      # name lookup, so an out-of-range idx gets this message rather than a
      # "subscript out of bounds" error from name_valid_sets
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Store data name
      data_name <- private$name_train_set

      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the prediction buffer on first use (relies on the native
      # routine writing the prediction count into 'npred' in place)
      if (is.null(private$predict_buffer[[data_name]])) {

        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer only if this iteration was not predicted yet
      if (!private$is_predicted_cur_iter[[idx]]) {

        # Use buffer
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Fetch and cache the evaluation metric names (and whether higher is better)
    get_eval_info = function() {

      if (is.null(private$eval_names)) {
        eval_names <- .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
        )

        if (length(eval_names) > 0L) {

          # Parse and store privately names
          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },

    # Evaluate dataset 'data_idx' on the built-in metrics and, if supplied, on
    # 'feval' (called as feval(preds, dataset); must return a list with 'name',
    # 'value' and 'higher_better'). Returns a list of result lists.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      self$restore_handle()

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Create evaluation values (filled in place by the native routine)
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {

          # Store evaluation and append to return
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Check if there are evaluation metrics
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set

        # Check if data to assess is existing differently
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        # Perform function evaluation
        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) ||  is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a
            list with attribute (name, value, higher_better)")
        }

        # Append names and evaluation
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }

  )
)

753
754
755
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
#'             a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @param start_iteration int or \code{NULL}, optional (default=\code{NULL})
#'                        Start index of the iteration to predict.
#'                        If \code{NULL} or <= 0, starts from the first iteration.
#' @param num_iteration int or \code{NULL}, optional (default=\code{NULL})
#'                      Limit number of iterations in the prediction.
#'                      If \code{NULL}, if the best iteration exists and start_iteration is \code{NULL} or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param params a list of additional named parameters. See
#'               \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#'               the "Predict Parameters" section of the documentation} for a list of parameters and
#'               valid values.
#' @param ... Additional prediction parameters. NOTE: deprecated as of v3.3.0. Use \code{params} instead.
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
#'
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#' preds <- predict(model, test$data)
#'
#' # pass other prediction parameters
#' preds <- predict(
#'     model,
#'     test$data,
#'     params = list(
#'         predict_disable_shape_check = TRUE
#'     )
#' )
#' }
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                params = list(),
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Deprecated path: values passed through '...' are still used (merged over
  # 'params' below, so they win on name clashes) but trigger a warning.
  additional_params <- list(...)
  if (length(additional_params) > 0L) {
    warning(paste0(
      "predict.lgb.Booster: Found the following passed through '...': "
      , paste(names(additional_params), collapse = ", ")
      , ". These will be used, but in future releases of lightgbm, this warning will become an error. "
      , "Add these to 'params' instead. See ?predict.lgb.Booster for documentation on how to call this function."
    ))
  }

  # '...' is already folded into 'params', so Booster$predict() receives no
  # '...' arguments and does not emit a second warning.
  return(
    object$predict(
      data = data
      , start_iteration = start_iteration
      , num_iteration = num_iteration
      , rawscore = rawscore
      , predleaf = predleaf
      , predcontrib = predcontrib
      , header = header
      , reshape = reshape
      , params = utils::modifyList(params, additional_params)
    )
  )
}

863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#'              If the model's \code{params} do not contain an \code{objective}, the
#'              objective line is omitted.
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
  # nolint start
  handle <- x$.__enclos_env__$private$handle
  handle_is_null <- lgb.is.null.handle(handle)

  if (!handle_is_null) {
    ntrees <- x$current_iter()
    if (ntrees == 1L) {
      cat("LightGBM Model (1 tree)\n")
    } else {
      cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
    }
  } else {
    cat("LightGBM Model\n")
  }

  if (!handle_is_null) {
    obj <- x$params$objective
    # 'params' need not contain an objective (e.g. a Booster constructed with
    # the default 'params = list()' and no training set). Previously this made
    # 'if (obj == "none")' fail with "argument is of length zero"; now the
    # objective line is simply skipped.
    if (!is.null(obj)) {
      if (isTRUE(obj == "none")) {
        obj <- "custom"
      }
      num_class <- x$.__enclos_env__$private$num_class
      if (num_class == 1L) {
        cat(sprintf("Objective: %s\n", obj))
      } else {
        cat(sprintf("Objective: %s (%d classes)\n"
            , obj
            , num_class))
      }
    }
  } else {
    cat("(Booster handle is invalid)\n")
  }

  if (!handle_is_null) {
    ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
    cat(sprintf("Fitted to dataset with %d columns\n", ncols))
  }
  # nolint end

  return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
  # summary() reports exactly what print() reports, so delegate to it and pass
  # its (invisible) return value straight through.
  return(print(x = object))
}

922
923
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
#' @param filename path of model file
#' @param model_str a str containing the model (as a `character` or `raw` vector)
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # A file path takes precedence over a model string
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # file.exists() is vectorized; without this guard a multi-element vector
    # would reach the `if` below with a length > 1 condition
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single string")
    }
    filename <- path.expand(filename)
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.raw(model_str) && !is.character(model_str)) {
      stop("lgb.load: model_str should be a character/raw vector")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

986
987
988
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to include in the saved model,
#'                      NULL or <= 0 means use best iteration
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  filename_is_string <- is.character(filename) && length(filename) == 1L
  if (!filename_is_string) {
    stop("lgb.save: filename should be a string")
  }

  # Expand "~" before handing the path over; whatever save_model() returns
  # is passed back to the caller invisibly.
  return(invisible(
    booster$save_model(
      filename = path.expand(filename)
      , num_iteration = num_iteration
    )
  ))
}

1042
1043
1044
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dump,
#'                      NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # The message names this function (it previously said "lgb.save:")
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))
}

1087
1088
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(
#'   objective = "regression"
#'   , metric = "l2"
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Each name must be a single string so that the `%in%` checks below produce
  # exactly one TRUE/FALSE for `if`
  is_single_string <- function(x) {
    is.character(x) && length(x) == 1L
  }
  if (!is_single_string(data_name) || !is_single_string(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check if evaluation result is existing
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick the stored values or their errors, depending on is_err
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Recorded results begin at iteration `start_iter`, so shift the requested
  # iteration numbers into positions within `result`. `[[` avoids the partial
  # name matching that `$` performs on lists.
  iters <- as.integer(iters)
  delta <- booster$record_evals[["start_iter"]] - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}