lgb.Booster.R 26.6 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
3
  classname = "lgb.Booster",
4
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
5
  public = list(
6

7
    best_iter = -1,
8
    best_score = NA,
Guolin Ke's avatar
Guolin Ke committed
9
    record_evals = list(),
10

11
12
    # Finalize will free up the handles
    finalize = function() {

      # Nothing to release when no native handle is held
      if (lgb.is.null.handle(private$handle)) {
        return(invisible(NULL))
      }

      # Release the native booster and forget the handle
      lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
      private$handle <- NULL

    },
24

25
26
    # Initialize will create a starter booster
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      # The booster is built from exactly one source, tried in this order:
      # a training lgb.Dataset, a model file path, or a model string.

      # Extra named arguments are appended to the parameter list
      params <- append(params, list(...))
      param_string <- lgb.params2str(params)

      # Placeholder that the native calls replace with the real handle
      booster_handle <- 0.0

      # Errors raised inside this block are caught by try(); a failure is
      # detected afterwards by checking whether the handle is still null
      try({

        if (!is.null(train_set)) {

          # Only lgb.Dataset objects are accepted as training data
          if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }
          dataset_private <- train_set$.__enclos_env__$private

          # Create the native booster on top of the dataset handle
          booster_handle <- lgb.call("LGBM_BoosterCreate_R",
                                     ret = booster_handle,
                                     dataset_private$get_handle(),
                                     param_string)

          # Remember the training data and its initial predictor
          private$train_set <- train_set
          private$num_dataset <- 1
          private$init_predictor <- dataset_private$predictor

          # Merge the initial predictor's booster into the new one, if any
          if (!is.null(private$init_predictor)) {
            lgb.call("LGBM_BoosterMerge_R",
                     ret = NULL,
                     booster_handle,
                     private$init_predictor$.__enclos_env__$private$handle)
          }

          # Register a "not yet predicted" flag for the training data
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          # The model file has to be given as a path string
          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          booster_handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R",
                                     ret = booster_handle,
                                     lgb.c_str(modelfile))

        } else if (!is.null(model_str)) {

          # The serialized model has to be given as a string
          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          booster_handle <- lgb.call("LGBM_BoosterLoadModelFromString_R",
                                     ret = booster_handle,
                                     lgb.c_str(model_str))

        } else {

          # No source was supplied at all
          stop("lgb.Booster: Need at least either training dataset, model file, or model_str to create booster instance")

        }

      })

      # A null handle here means creation failed (or was stopped) above
      if (lgb.is.null.handle(booster_handle)) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      # Tag and store the handle, then query the number of classes
      class(booster_handle) <- "lgb.Booster.handle"
      private$handle <- booster_handle
      private$num_class <- 1L
      private$num_class <- lgb.call("LGBM_BoosterGetNumClasses_R",
                                    ret = private$num_class,
                                    private$handle)

    },
121

122
    # Set training data name
Guolin Ke's avatar
Guolin Ke committed
123
    set_train_data_name = function(name) {

      # Store the label used for the training data in evaluation results
      private$name_train_set <- name

      # Return the booster invisibly so calls can be chained
      invisible(self)

    },
130

131
    # Add validation data
Guolin Ke's avatar
Guolin Ke committed
132
    add_valid = function(data, name) {

      # Validation data must be an lgb.Dataset
      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }
      data_private <- data$.__enclos_env__$private

      # The dataset has to share the booster's initial predictor
      if (!identical(data_private$predictor, private$init_predictor)) {
        stop("lgb.Booster.add_valid: Failed to add validation data; you should use the same predictor for these data")
      }

      # The dataset name has to be a character value
      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Attach the dataset to the native booster
      lgb.call("LGBM_BoosterAddValidData_R",
               ret = NULL,
               private$handle,
               data_private$get_handle())

      # Record the dataset, its name and a fresh "not predicted" flag
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      # Return the booster invisibly so calls can be chained
      invisible(self)

    },
165

166
    # Reset parameters of booster
Guolin Ke's avatar
Guolin Ke committed
167
    reset_parameter = function(params, ...) {

      # Merge the explicit list with any extra named arguments and serialize
      merged_params <- append(params, list(...))
      param_string <- lgb.params2str(merged_params)

      # Push the new parameters to the native booster
      lgb.call("LGBM_BoosterResetParameter_R",
               ret = NULL,
               private$handle,
               param_string)

      # Return the booster invisibly so calls can be chained
      invisible(self)

    },
183

184
    # Perform boosting update iteration
Guolin Ke's avatar
Guolin Ke committed
185
    update = function(train_set = NULL, fobj = NULL) {

      # Performs one boosting iteration and returns the result of the
      # native update call.
      #   train_set: optional lgb.Dataset that replaces the training data
      #   fobj:      optional custom objective, called as
      #              fobj(predictions, train_set); it must return a list
      #              with elements `grad` and `hess`

      # Optionally swap the training data
      if (!is.null(train_set)) {

        # Check if training set is lgb.Dataset
        if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Check if predictors are identical. The dataset's predictor is read
        # through its private environment, the same way initialize() and
        # add_valid() read it.
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        # Reset training data on booster
        lgb.call("LGBM_BoosterResetTrainingData_R",
                 ret = NULL,
                 private$handle,
                 train_set$.__enclos_env__$private$get_handle())

        # Store private train set
        private$train_set <- train_set

      }

      if (is.null(fobj)) {

        # Once the objective has been reset to "none" a custom objective
        # function is required for every further update
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        # Boost iteration from known objective
        ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)

      } else {

        # Check if objective is function
        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # On first use of a custom objective, reset the objective to "none"
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        # Evaluate the custom objective on the current training predictions
        gpair <- fobj(private$inner_predict(1), private$train_set)

        # The result must carry both gradient and hessian
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should return a list with attributes (hess, grad)")
        }

        # Boost one iteration from the custom gradient/hessian
        ret <- lgb.call("LGBM_BoosterUpdateOneIterCustom_R",
                        ret = NULL,
                        private$handle,
                        gpair$grad,
                        gpair$hess,
                        length(gpair$grad))

      }

      # Cached predictions of every dataset are stale after the update
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(ret)

    },
256

257
    # Return one iteration behind
Guolin Ke's avatar
Guolin Ke committed
258
    rollback_one_iter = function() {

      # Undo the most recent boosting iteration in the native booster
      lgb.call("LGBM_BoosterRollbackOneIter_R",
               ret = NULL,
               private$handle)

      # Mark the cached predictions of every dataset as stale
      private$is_predicted_cur_iter <- lapply(
        private$is_predicted_cur_iter,
        function(flag) FALSE
      )

      # Return the booster invisibly so calls can be chained
      invisible(self)

    },
274

275
    # Get current iteration
Guolin Ke's avatar
Guolin Ke committed
276
    current_iter = function() {

      # Integer placeholder passed as `ret` to the native call, whose
      # result is returned
      iteration <- 0L
      lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = iteration, private$handle)

    },
284

285
    # Evaluate data on metrics
Guolin Ke's avatar
Guolin Ke committed
286
    eval = function(data, name, feval = NULL) {

      # Only lgb.Dataset objects can be evaluated
      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate `data` among the datasets already known to the booster:
      # index 1 is the training set, index k + 1 the k-th validation set,
      # and 0 means "not registered yet"
      if (identical(data, private$train_set)) {
        data_idx <- 1
      } else {
        valid_pos <- Position(
          function(candidate) identical(data, candidate),
          private$valid_sets,
          nomatch = 0L
        )
        data_idx <- if (valid_pos > 0L) valid_pos + 1 else 0
      }

      # Unknown data is registered as a new validation set under `name`
      if (data_idx == 0) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      # Evaluate the selected dataset
      private$inner_eval(name, data_idx, feval)

    },
333

334
    # Evaluation training data
Guolin Ke's avatar
Guolin Ke committed
335
    eval_train = function(feval = NULL) {
      # The training data is always dataset number 1
      train_idx <- 1
      private$inner_eval(private$name_train_set, train_idx, feval)
    },
338

339
    # Evaluation validation data
Guolin Ke's avatar
Guolin Ke committed
340
    eval_valid = function(feval = NULL) {

      # Evaluate each validation set in order; the k-th validation set is
      # dataset number k + 1 (dataset 1 is the training data)
      per_set <- lapply(seq_along(private$valid_sets), function(i) {
        private$inner_eval(private$name_valid_sets[[i]], i + 1, feval)
      })

      # Concatenate the per-set results into one flat list
      # (an empty list when there are no validation sets)
      Reduce(function(acc, chunk) append(acc, chunk), per_set, list())

    },
359

360
    # Save model
Guolin Ke's avatar
Guolin Ke committed
361
    save_model = function(filename, num_iteration = NULL) {

      # Fall back to the recorded best iteration when none is requested
      iter_to_save <- if (is.null(num_iteration)) self$best_iter else num_iteration

      # Write the model to `filename` through the native library
      lgb.call("LGBM_BoosterSaveModel_R",
               ret = NULL,
               private$handle,
               as.integer(iter_to_save),
               lgb.c_str(filename))

      # Return the booster invisibly so calls can be chained
      invisible(self)
    },
378

379
380
    # Save model to string
    save_model_to_string = function(num_iteration = NULL) {

      # Fall back to the recorded best iteration when none is requested
      iter_to_save <- if (is.null(num_iteration)) self$best_iter else num_iteration

      # Serialize the model and hand back the resulting string
      lgb.call.return.str("LGBM_BoosterSaveModelToString_R",
                          private$handle,
                          as.integer(iter_to_save))

    },
393

394
    # Dump model in memory
Guolin Ke's avatar
Guolin Ke committed
395
    dump_model = function(num_iteration = NULL) {

      # Fall back to the recorded best iteration when none is requested
      iter_to_dump <- if (is.null(num_iteration)) self$best_iter else num_iteration

      # Return the string produced by the native dump call
      lgb.call.return.str("LGBM_BoosterDumpModel_R", private$handle, as.integer(iter_to_dump))

    },
408

409
    # Predict on new data
Guolin Ke's avatar
Guolin Ke committed
410
    predict = function(data,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE, ...) {

      # Fall back to the recorded best iteration when none is requested
      iter_to_use <- if (is.null(num_iteration)) self$best_iter else num_iteration

      # Build a Predictor on the booster handle (extra arguments go to its
      # constructor) and delegate the prediction to it
      Predictor$new(private$handle, ...)$predict(
        data,
        iter_to_use,
        rawscore,
        predleaf,
        predcontrib,
        header,
        reshape
      )

    },
428

429
430
431
    # Transform into predictor
    to_predictor = function() {
      # Wrap the booster handle in a new Predictor object
      predictor <- Predictor$new(private$handle)
      predictor
    },
433

434
    # Used for save
435
    raw = NA,
436

437
    # Save model to temporary file for in-memory saving
438
    save = function() {

      # Serialize the model (NULL selects the default iteration) and keep
      # the string in the public `raw` field
      model_string <- self$save_model_to_string(NULL)
      self$raw <- model_string

    }
444

Guolin Ke's avatar
Guolin Ke committed
445
446
  ),
  private = list(
447
448
449
450
451
452
453
454
455
456
457
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1,
    num_dataset = 0,
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
458
    higher_better_inner_eval = NULL,
459
    set_objective_to_none = FALSE,
460
461
    # Predict data
    inner_predict = function(idx) {

      # Returns the cached predictions for dataset number `idx`
      # (1 = training data, k + 1 = k-th validation set), refreshing the
      # cache when it has not been computed for the current iteration.

      # Validate the index first, so an out-of-range value produces this
      # message rather than a subscript error from the name lookup below
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Resolve the name under which this dataset's buffer is stored
      data_name <- private$name_train_set
      if (idx > 1) {
        data_name <- private$name_valid_sets[[idx - 1]]
      }

      # Allocate the prediction buffer on first use
      if (is.null(private$predict_buffer[[data_name]])) {

        # Ask the native library how many predictions there are
        # (native dataset indices are zero-based, hence idx - 1)
        npred <- 0L
        npred <- lgb.call("LGBM_BoosterGetNumPredict_R",
                          ret = npred,
                          private$handle,
                          as.integer(idx - 1))
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer if it has not been filled for this iteration
      if (!private$is_predicted_cur_iter[[idx]]) {

        private$predict_buffer[[data_name]] <- lgb.call("LGBM_BoosterGetPredict_R",
                                                        ret = private$predict_buffer[[data_name]],
                                                        private$handle,
                                                        as.integer(idx - 1))
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      # Return prediction buffer
      return(private$predict_buffer[[data_name]])
    },
503

504
    # Get evaluation information
Guolin Ke's avatar
Guolin Ke committed
505
    get_eval_info = function() {

      # Metric names are fetched once and cached afterwards
      if (!is.null(private$eval_names)) {
        return(private$eval_names)
      }

      # The native call returns the metric names as one string
      raw_names <- lgb.call.return.str("LGBM_BoosterGetEvalNames_R",
                                       private$handle)

      # An empty string means no metrics: leave the cache unset
      if (nchar(raw_names) > 0) {

        # Names are tab-separated; also record which metrics match the
        # "higher is better" name pattern
        metric_names <- strsplit(raw_names, "\t")[[1]]
        private$eval_names <- metric_names
        private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc$", metric_names)

      }

      # Return evaluation names (NULL when none were found)
      return(private$eval_names)

    },
530

531
    # Perform inner evaluation
Guolin Ke's avatar
Guolin Ke committed
532
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Evaluates dataset number `data_idx` (1 = training data,
      # k + 1 = k-th validation set) and returns a list of results, each a
      # list with elements data_name, name, value and higher_better.
      #   data_name: label stored in every result
      #   feval:     optional custom metric, called as
      #              feval(predictions, dataset); it must return a list
      #              with elements name, value and higher_better

      # Check for unknown dataset (over the maximum provided range)
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Make sure the metric names are loaded
      private$get_eval_info()

      # Prepare return
      ret <- list()

      # Built-in metrics, if the booster has any
      if (length(private$eval_names) > 0) {

        # One value per metric, filled in by the native call
        # (native dataset indices are zero-based, hence data_idx - 1)
        tmp_vals <- numeric(length(private$eval_names))
        tmp_vals <- lgb.call("LGBM_BoosterGetEval_R",
                             ret = tmp_vals,
                             private$handle,
                             as.integer(data_idx - 1))

        # One result entry per metric
        for (i in seq_along(private$eval_names)) {

          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Custom metric, if supplied
      if (!is.null(feval)) {

        # Check if evaluation metric is a function
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset object that matches data_idx
        data <- private$train_set
        if (data_idx > 1) {
          data <- private$valid_sets[[data_idx - 1]]
        }

        # Evaluate the custom metric on the current predictions
        res <- feval(private$inner_predict(data_idx), data)

        # The result must provide all three required elements
        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)")
        }

        # Label the result and append it
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      # Return ret
      return(ret)

    }
604

Guolin Ke's avatar
Guolin Ke committed
605
606
607
608
609
  )
)


#' Predict method for LightGBM model
610
#'
Guolin Ke's avatar
Guolin Ke committed
611
#' Predicted values based on class \code{lgb.Booster}
612
#'
Guolin Ke's avatar
Guolin Ke committed
613
614
615
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param num_iteration number of iterations to predict with, NULL or <= 0 means use best iteration
616
617
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'        sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for
Guolin Ke's avatar
Guolin Ke committed
618
#'        logistic regression would result in predictions for log-odds instead of probabilities.
619
#' @param predleaf whether predict leaf index instead.
620
#' @param predcontrib return per-feature contributions for each record.
Guolin Ke's avatar
Guolin Ke committed
621
#' @param header only used for prediction for text file. True if text file has header
622
623
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'        prediction outputs per case.
James Lamb's avatar
James Lamb committed
624
625
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
626
#' @return
Guolin Ke's avatar
Guolin Ke committed
627
#' For regression or binary classification, it returns a vector of length \code{nrows(data)}.
628
629
#' For multiclass classification, either a \code{num_class * nrows(data)} vector or
#' a \code{(nrows(data), num_class)} dimension matrix is returned, depending on
Guolin Ke's avatar
Guolin Ke committed
630
#' the \code{reshape} value.
631
632
#'
#' When \code{predleaf = TRUE}, the output is a matrix object with the
Guolin Ke's avatar
Guolin Ke committed
633
#' number of columns corresponding to the number of trees.
634
#'
Guolin Ke's avatar
Guolin Ke committed
635
#' @examples
636
637
638
639
640
641
642
643
644
645
646
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
647
#'                    10,
648
649
650
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
651
#'                    early_stopping_rounds = 5)
652
#' preds <- predict(model, test$data)
653
#'
Guolin Ke's avatar
Guolin Ke committed
654
655
#' @rdname predict.lgb.Booster
#' @export
James Lamb's avatar
James Lamb committed
656
657
658
659
660
661
662
predict.lgb.Booster <- function(object,
                                data,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  # Refuse anything that is not an lgb.Booster
  is_booster <- lgb.is.Booster(object)
  if (!is_booster) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the booster's own predict() method, forwarding every
  # argument in the order that method expects
  object$predict(
    data,
    num_iteration,
    rawscore,
    predleaf,
    predcontrib,
    header,
    reshape,
    ...
  )
}

#' Load LightGBM model
682
#'
683
684
685
#' Load LightGBM model from saved model file or string
#' Load LightGBM takes in either a file path or model string
#' If both are provided, Load will default to loading from file
686
#'
Guolin Ke's avatar
Guolin Ke committed
687
#' @param filename path of model file
688
#' @param model_str a str containing the model
689
#'
690
#' @return lgb.Booster
691
#'
Guolin Ke's avatar
Guolin Ke committed
692
#' @examples
693
694
695
696
697
698
699
700
701
702
703
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
704
#'                    10,
705
706
707
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
708
#'                    early_stopping_rounds = 5)
709
#' lgb.save(model, "model.txt")
710
711
712
#' load_booster <- lgb.load(filename = "model.txt")
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
713
#'
714
#' @rdname lgb.load
Guolin Ke's avatar
Guolin Ke committed
715
#' @export
716
lgb.load <- function(filename = NULL, model_str = NULL) {

  has_file <- !is.null(filename)

  # At least one source is required
  if (!has_file && is.null(model_str)) {
    stop("lgb.load: either filename or model_str must be given")
  }

  # A file name takes precedence over a model string
  if (has_file && !is.character(filename)) {
    stop("lgb.load: filename should be character")
  }
  if (has_file && !file.exists(filename)) {
    stop("lgb.load: file does not exist for supplied filename")
  }
  if (has_file) {
    return(invisible(Booster$new(modelfile = filename)))
  }

  # No file name was given, so model_str is the (non-NULL) source
  if (!is.character(model_str)) {
    stop("lgb.load: model_str should be character")
  }
  invisible(Booster$new(model_str = model_str))

}

#' Save LightGBM model
741
#'
Guolin Ke's avatar
Guolin Ke committed
742
#' Save LightGBM model
743
#'
Guolin Ke's avatar
Guolin Ke committed
744
745
746
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
747
#'
748
#' @return lgb.Booster
749
#'
Guolin Ke's avatar
Guolin Ke committed
750
#' @examples
751
752
753
754
755
756
757
758
759
760
761
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
762
#'                    10,
763
764
765
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
766
#'                    early_stopping_rounds = 5)
767
#' lgb.save(model, "model.txt")
768
#'
769
#' @rdname lgb.save
Guolin Ke's avatar
Guolin Ke committed
770
#' @export
771
lgb.save <- function(booster, filename, num_iteration = NULL) {

  # Only lgb.Booster objects can be saved
  is_booster <- lgb.is.Booster(booster)
  if (!is_booster) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # The target path has to be a character value
  if (!is.character(filename)) {
    stop("lgb.save: filename should be a character")
  }

  # Delegate to the booster and return its result invisibly
  saved <- booster$save_model(filename, num_iteration)
  invisible(saved)

}

#' Dump LightGBM model to json
789
#'
Guolin Ke's avatar
Guolin Ke committed
790
#' Dump LightGBM model to json
791
#'
Guolin Ke's avatar
Guolin Ke committed
792
793
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump, NULL or <= 0 means use best iteration
794
#'
Guolin Ke's avatar
Guolin Ke committed
795
#' @return json format of model
796
#'
Guolin Ke's avatar
Guolin Ke committed
797
#' @examples
798
799
800
801
802
803
804
805
806
807
808
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                   dtrain,
809
#'                    10,
810
811
812
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
813
#'                    early_stopping_rounds = 5)
814
#' json_model <- lgb.dump(model)
815
#'
816
#' @rdname lgb.dump
Guolin Ke's avatar
Guolin Ke committed
817
#' @export
818
lgb.dump <- function(booster, num_iteration = NULL){
819

820
821
822
823
  # Check if booster is booster
  if (!lgb.is.Booster(booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }
824

825
  # Return booster at requested iteration
Guolin Ke's avatar
Guolin Ke committed
826
  booster$dump_model(num_iteration)
827

Guolin Ke's avatar
Guolin Ke committed
828
829
830
}

#' Get record evaluation result from booster
#'
#' Get record evaluation result from booster
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name name of the dataset, used as a key into the booster's \code{record_evals}
#' @param eval_name name of the evaluation, used as a key below \code{data_name}
#' @param iters iterations to return, NULL will return all
#' @param is_err TRUE will return the recorded evaluation error (\code{eval_err}) instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(params,
#'                    dtrain,
#'                    10,
#'                    valids,
#'                    min_data = 1,
#'                    learning_rate = 1,
#'                    early_stopping_rounds = 5)
#' lgb.get.eval.result(model, "test", "l2")
#' @rdname lgb.get.eval.result
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Only boosters carry the record_evals structure read below
  if (!lgb.is.Booster(booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Both lookup keys have to be character values
  names_are_character <- is.character(data_name) && is.character(eval_name)
  if (!names_are_character) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # Walk down the record one level at a time, failing at the first missing key
  data_record <- booster$record_evals[[data_name]]
  if (is.null(data_record)) {
    stop("lgb.get.eval.result: wrong data name")
  }

  eval_record <- data_record[[eval_name]]
  if (is.null(eval_record)) {
    stop("lgb.get.eval.result: wrong eval name")
  }

  # Pick either the evaluation values or their recorded errors
  result <- if (is_err) eval_record$eval_err else eval_record$eval

  # Without explicit iterations everything is returned; otherwise the requested
  # iteration numbers are shifted by (start_iter - 1) to obtain positions
  # inside the recorded results.
  selected <- if (is.null(iters)) {
    result
  } else {
    requested <- as.integer(iters)
    offset <- booster$record_evals$start_iter - 1
    result[requested - offset]
  }

  as.numeric(selected)

}