# lgb.Booster.R
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Best iteration; used as the default `num_iteration` by save_model(),
    # save_model_to_string(), dump_model() and predict(). Not assigned
    # anywhere inside this class.
    best_iter = -1,
    best_score = NA,
    record_evals = list(),

    # R6 finalizer: free the native booster handle if one is still held
    finalize = function() {

      if (!lgb.is.null.handle(private$handle)) {
        lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
        private$handle <- NULL
      }

    },

    # Create a booster from (checked in this order) a training lgb.Dataset,
    # a model file path, or a model string. Named arguments in `...` are
    # appended to `params` before they are serialized with lgb.params2str().
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      params <- append(params, list(...))
      params_str <- lgb.params2str(params)
      handle <- 0.0

      # try() reports (but does not re-raise) errors from this block. If
      # `handle` still holds its 0.0 placeholder afterwards, the
      # lgb.is.null.handle() check below stops instead (assumes that helper
      # treats 0.0 as a null handle).
      try({

        if (!is.null(train_set)) {

          if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }

          handle <- lgb.call(
            "LGBM_BoosterCreate_R"
            , ret = handle
            , train_set$.__enclos_env__$private$get_handle()
            , params_str
          )

          # The training set occupies dataset index 1
          private$train_set <- train_set
          private$num_dataset <- 1
          private$init_predictor <- train_set$.__enclos_env__$private$predictor

          # Continue from the dataset's initial predictor, if it has one
          if (!is.null(private$init_predictor)) {
            lgb.call(
              "LGBM_BoosterMerge_R"
              , ret = NULL
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )
          }

          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          handle <- lgb.call(
            "LGBM_BoosterCreateFromModelfile_R"
            , ret = handle
            , lgb.c_str(modelfile)
          )

        } else if (!is.null(model_str)) {

          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          handle <- lgb.call(
            "LGBM_BoosterLoadModelFromString_R"
            , ret = handle
            , lgb.c_str(model_str)
          )

        } else {

          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      if (lgb.is.null.handle(handle)) {

        stop("lgb.Booster: cannot create Booster handle")

      } else {

        class(handle) <- "lgb.Booster.handle"
        private$handle <- handle
        private$num_class <- 1L
        private$num_class <- lgb.call(
          "LGBM_BoosterGetNumClasses_R"
          , ret = private$num_class
          , private$handle
        )

      }

    },

    # Set the name used for the training set in evaluation results
    set_train_data_name = function(name) {
      private$name_train_set <- name
      return(invisible(self))
    },

    # Register an lgb.Dataset as an additional validation set. It receives
    # the next dataset index (training set is 1, first validation set is 2).
    add_valid = function(data, name) {

      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # Validation data must share the booster's initial predictor
      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      lgb.call(
        "LGBM_BoosterAddValidData_R"
        , ret = NULL
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Update booster parameters; named arguments in `...` are appended to
    # `params`
    reset_parameter = function(params, ...) {

      params <- append(params, list(...))
      params_str <- lgb.params2str(params)

      lgb.call(
        "LGBM_BoosterResetParameter_R"
        , ret = NULL
        , private$handle
        , params_str
      )

      return(invisible(self))

    },

    # Run one boosting iteration, optionally after swapping in a new training
    # set. With `fobj = NULL` the booster's configured objective is used;
    # otherwise `fobj(preds, train_set)` must return a list with `grad` and
    # `hess`. Returns the value of the underlying lgb.call().
    update = function(train_set = NULL, fobj = NULL) {

      if (!is.null(train_set)) {

        if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # `predictor` is private on lgb.Dataset, so read it the same way
        # initialize() and add_valid() do; `train_set$predictor` does not
        # reach the private field
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        lgb.call(
          "LGBM_BoosterResetTrainingData_R"
          , ret = NULL
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set

      }

      if (is.null(fobj)) {

        # Once a custom objective has switched the booster to objective
        # "none", it cannot fall back to a built-in objective
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # Switch off the built-in objective the first time a custom one is used
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        gpair <- fobj(private$inner_predict(1), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective should "
            , "return a list with attributes (hess, grad)"
          )
        }

        ret <- lgb.call(
          "LGBM_BoosterUpdateOneIterCustom_R"
          , ret = NULL
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # The model changed, so every cached prediction is stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(ret)

    },

    # Remove the most recent boosting iteration
    rollback_one_iter = function() {

      lgb.call(
        "LGBM_BoosterRollbackOneIter_R"
        , ret = NULL
        , private$handle
      )

      # The model changed, so every cached prediction is stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Return the current iteration number reported by LightGBM
    current_iter = function() {

      cur_iter <- 0L
      lgb.call(
        "LGBM_BoosterGetCurrentIteration_R"
        , ret = cur_iter
        , private$handle
      )

    },

    # Evaluate `data` with the booster's metrics (plus `feval`, if given).
    # A dataset not yet known to the booster is registered via add_valid().
    eval = function(data, name, feval = NULL) {

      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the dataset index: 1 = training set, i + 1 = validation set i,
      # 0 = not registered yet
      data_idx <- 0
      if (identical(data, private$train_set)) {
        data_idx <- 1
      } else {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1
            break
          }
        }
      }

      if (data_idx == 0) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      private$inner_eval(name, data_idx, feval)

    },

    # Evaluate the training set (dataset index 1)
    eval_train = function(feval = NULL) {
      private$inner_eval(private$name_train_set, 1, feval)
    },

    # Evaluate every registered validation set; results are concatenated
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1, feval)
        )
      }

      return(ret)

    },

    # Write the model to `filename`; num_iteration = NULL uses best_iter
    save_model = function(filename, num_iteration = NULL) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      lgb.call(
        "LGBM_BoosterSaveModel_R"
        , ret = NULL
        , private$handle
        , as.integer(num_iteration)
        , lgb.c_str(filename)
      )

      return(invisible(self))

    },

    # Return the model as a string; num_iteration = NULL uses best_iter
    save_model_to_string = function(num_iteration = NULL) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      return(lgb.call.return.str(
        "LGBM_BoosterSaveModelToString_R"
        , private$handle
        , as.integer(num_iteration)
      ))

    },

    # Return the model dump produced by LGBM_BoosterDumpModel_R;
    # num_iteration = NULL uses best_iter
    dump_model = function(num_iteration = NULL) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      lgb.call.return.str(
        "LGBM_BoosterDumpModel_R"
        , private$handle
        , as.integer(num_iteration)
      )

    },

    # Predict on new data through a fresh Predictor built from this booster's
    # handle; extra arguments in `...` go to Predictor$new()
    predict = function(data,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE, ...) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      predictor <- Predictor$new(private$handle, ...)
      predictor$predict(data, num_iteration, rawscore, predleaf, predcontrib, header, reshape)

    },

    # Wrap this booster's handle in a Predictor
    to_predictor = function() {
      Predictor$new(private$handle)
    },

    # Model string cached by save() for in-memory serialization
    raw = NA,

    # Store the model string (at best_iter) in self$raw
    save = function() {
      self$raw <- self$save_model_to_string(NULL)
    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    # Cached predictions per dataset, keyed by dataset name
    predict_buffer = list(),
    # One flag per dataset index: TRUE while predict_buffer is up to date
    is_predicted_cur_iter = list(),
    num_class = 1,
    num_dataset = 0,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,

    # Return (and cache) the booster's current predictions for dataset `idx`
    # (1 = training set, i + 1 = validation set i)
    inner_predict = function(idx) {

      # Validate before `idx` is used to look up a validation-set name, so an
      # out-of-range index gets this message instead of "subscript out of bounds"
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      data_name <- private$name_train_set
      if (idx > 1) {
        data_name <- private$name_valid_sets[[idx - 1]]
      }

      # Allocate the buffer on first use
      if (is.null(private$predict_buffer[[data_name]])) {
        npred <- 0L
        npred <- lgb.call(
          "LGBM_BoosterGetNumPredict_R"
          , ret = npred
          , private$handle
          , as.integer(idx - 1)
        )
        private$predict_buffer[[data_name]] <- numeric(npred)
      }

      # Refresh the buffer only if update()/rollback_one_iter() cleared the flag
      if (!private$is_predicted_cur_iter[[idx]]) {
        private$predict_buffer[[data_name]] <- lgb.call(
          "LGBM_BoosterGetPredict_R"
          , ret = private$predict_buffer[[data_name]]
          , private$handle
          , as.integer(idx - 1)
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])

    },

    # Fetch (once) and cache the metric names, plus a flag per metric telling
    # whether a higher value is better
    get_eval_info = function() {

      if (is.null(private$eval_names)) {

        eval_names <- lgb.call.return.str(
          "LGBM_BoosterGetEvalNames_R"
          , private$handle
        )

        # Names arrive as a single tab-separated string
        if (nchar(eval_names) > 0) {
          eval_names <- strsplit(eval_names, "\t")[[1]]
          private$eval_names <- eval_names
          private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc$", eval_names)
        }

      }

      return(private$eval_names)

    },

    # Evaluate dataset `data_idx` on the built-in metrics and, if given, the
    # custom `feval`. Returns a list of records; each has at least the fields
    # data_name, name, value and higher_better.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0) {

        tmp_vals <- numeric(length(private$eval_names))
        tmp_vals <- lgb.call(
          "LGBM_BoosterGetEval_R"
          , ret = tmp_vals
          , private$handle
          , as.integer(data_idx - 1)
        )

        for (i in seq_along(private$eval_names)) {
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))
        }

      }

      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Index 1 is the training set; otherwise use validation set idx - 1
        data <- private$train_set
        if (data_idx > 1) {
          data <- private$valid_sets[[data_idx - 1]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a "
            , "list with attribute (name, value, higher_better)"
          )
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))

      }

      return(ret)

    }

  )
)


#' Predict method for LightGBM model
#'
#' Predicted values based on class \code{lgb.Booster}
#'
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param num_iteration number of iterations to predict with; \code{NULL} means use the booster's
#'                      best iteration (its \code{best_iter} field)
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf indices instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. \code{TRUE} if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
#' @return
#' For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#' For multiclass classification, either a \code{num_class * nrow(data)} vector or
#' a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#' the \code{reshape} value.
#'
#' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees.
#'
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10
#'   , valids = valids
#'   , min_data = 1
#'   , learning_rate = 1
#'   , early_stopping_rounds = 5
#' )
#' preds <- predict(model, test$data)
#'
#' @rdname predict.lgb.Booster
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  if (!lgb.is.Booster(object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Forward by name so the call does not depend on the argument order of
  # the Booster's predict() method
  object$predict(
    data = data
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , reshape = reshape
    , ...
  )

}

#' Load LightGBM model
#'
#' Load LightGBM model from saved model file or string.
#' \code{lgb.load} takes in either a file path or a model string;
#' if both are provided, the model is loaded from the file.
#'
#' @param filename path of model file (a single character string)
#' @param model_str a character string containing the model
#'
#' @return lgb.Booster
#'
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10
#'   , valids = valids
#'   , min_data = 1
#'   , learning_rate = 1
#'   , early_stopping_rounds = 5
#' )
#' lgb.save(model, "model.txt")
#' load_booster <- lgb.load(filename = "model.txt")
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#'
#' @rdname lgb.load
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  if (is.null(filename) && is.null(model_str)) {
    stop("lgb.load: either filename or model_str must be given")
  }

  # A filename takes precedence; model_str is not inspected in that case
  if (!is.null(filename)) {

    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }

    # file.exists() is vectorized; a non-scalar path would otherwise reach
    # a scalar condition and fail with an unhelpful message
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single character string")
    }

    if (!file.exists(filename)) {
      stop("lgb.load: file does not exist for supplied filename")
    }

    return(invisible(Booster$new(modelfile = filename)))

  }

  # Only model_str remains (the first check guarantees it is not NULL)
  if (!is.character(model_str)) {
    stop("lgb.load: model_str should be character")
  }

  invisible(Booster$new(model_str = model_str))

}

#' Save LightGBM model
#'
#' Save an \code{lgb.Booster} model to a file.
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save; \code{NULL} means use the booster's
#'                      best iteration (its \code{best_iter} field)
#'
#' @return the input \code{booster} (an \code{lgb.Booster}), invisibly
#'
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10
#'   , valids = valids
#'   , min_data = 1
#'   , learning_rate = 1
#'   , early_stopping_rounds = 5
#' )
#' lgb.save(model, "model.txt")
#'
#' @rdname lgb.save
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  if (!is.character(filename)) {
    stop("lgb.save: filename should be a character")
  }

  # save_model() returns the booster itself
  invisible(booster$save_model(filename, num_iteration))

}

#' Dump LightGBM model to json
#'
#' Dump LightGBM model to json
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump; \code{NULL} means use the booster's
#'                      best iteration (its \code{best_iter} field)
#'
#' @return the model in JSON format, as a character string
#'
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10
#'   , valids = valids
#'   , min_data = 1
#'   , learning_rate = 1
#'   , early_stopping_rounds = 5
#' )
#' json_model <- lgb.dump(model)
#'
#' @rdname lgb.dump
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # Error prefix names this function (it previously said "lgb.save:")
  if (!lgb.is.Booster(booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  booster$dump_model(num_iteration)

}

#' Get record evaluation result from booster
#'
#' Get record evaluation result from booster
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name name of dataset
#' @param eval_name name of evaluation
#' @param iters iterations, NULL will return all
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation results. When \code{iters} is given,
#'         the result has one element per requested iteration, with \code{NA}
#'         for iterations that were not recorded.
#'
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10
#'   , valids = valids
#'   , min_data = 1
#'   , learning_rate = 1
#'   , early_stopping_rounds = 5
#' )
#' lgb.get.eval.result(model, "test", "l2")
#' @rdname lgb.get.eval.result
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Check if booster is booster
  if (!lgb.is.Booster(booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Check if data and evaluation name are characters or not
  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  record_evals <- booster$record_evals

  # Check if recorded evaluation is existing
  if (is.null(record_evals[[data_name]])) {
    stop("lgb.get.eval.result: wrong data name")
  }

  # Check if evaluation result is existing
  if (is.null(record_evals[[data_name]][[eval_name]])) {
    stop("lgb.get.eval.result: wrong eval name")
  }

  # Pick the recorded values or, when requested, the recorded errors.
  # `[[` is used instead of `$` so a missing field is never partially matched.
  field <- if (is_err) "eval_err" else "eval"
  result <- as.numeric(record_evals[[data_name]][[eval_name]][[field]])

  # No iterations requested: return everything that was recorded
  if (is.null(iters)) {
    return(result)
  }

  # result[1] corresponds to iteration `start_iter`, so shift the requested
  # iterations into positions within `result` (missing start_iter -> 1)
  start_iter <- record_evals[["start_iter"]]
  if (is.null(start_iter)) {
    start_iter <- 1L
  }
  positions <- as.integer(iters) - (start_iter - 1L)

  # Positions outside the recorded range become NA. Indexing with them
  # directly would let zero/negative values act as R's exclusion indices
  # and silently return the wrong elements.
  out <- rep(NA_real_, length(positions))
  valid <- !is.na(positions) & positions >= 1L & positions <= length(result)
  out[valid] <- result[positions[valid]]
  out

}