lgb.Booster.R 30.1 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Best iteration; -1L by default. Used as the default `num_iteration` by
    # save_model(), save_model_to_string(), dump_model() and predict().
    best_iter = -1L,
    best_score = NA_real_,
    # Parameters the booster was created with / last reset to
    params = list(),
    record_evals = list(),

    # Finalize will free up the native booster handle, if one is held
    finalize = function() {

      if (!lgb.is.null.handle(x = private$handle)) {

        lgb.call(fun_name = "LGBM_BoosterFree_R", ret = NULL, private$handle)
        private$handle <- NULL

      }

    },

    # Create a booster from exactly one source, checked in this order:
    #   1. `train_set` - an lgb.Dataset to train on (its parameters are merged in)
    #   2. `modelfile` - path to a saved model file
    #   3. `model_str` - a model serialized as a string
    # Extra arguments in `...` are appended to `params`.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      # Create parameters and handle
      params <- append(params, list(...))
      handle <- lgb.null.handle()

      # Errors raised inside try() are printed but not propagated; a failed
      # construction is detected afterwards through the still-null handle.
      try({

        if (!is.null(train_set)) {

          if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }

          train_set_handle <- train_set$.__enclos_env__$private$get_handle()
          # modifyList(): the dataset's parameters override the ones passed in
          params <- modifyList(params, train_set$get_params())
          params_str <- lgb.params2str(params = params)

          handle <- lgb.call(
            fun_name = "LGBM_BoosterCreate_R"
            , ret = handle
            , train_set_handle
            , params_str
          )

          # Create private booster information
          private$train_set <- train_set
          private$train_set_version <- train_set$.__enclos_env__$private$version
          private$num_dataset <- 1L
          private$init_predictor <- train_set$.__enclos_env__$private$predictor

          # Merge the predictor attached to the dataset (if any) into this booster
          if (!is.null(private$init_predictor)) {

            lgb.call(
              fun_name = "LGBM_BoosterMerge_R"
              , ret = NULL
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )

          }

          # Prediction-cache flag for the training set (dataset index 1L)
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          handle <- lgb.call(
            fun_name = "LGBM_BoosterCreateFromModelfile_R"
            , ret = handle
            , lgb.c_str(x = modelfile)
          )

        } else if (!is.null(model_str)) {

          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          handle <- lgb.call(
            fun_name = "LGBM_BoosterLoadModelFromString_R"
            , ret = handle
            , lgb.c_str(x = model_str)
          )

        } else {

          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      # Any failure above (including a stop() swallowed by try()) leaves the handle null
      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      private$num_class <- 1L
      private$num_class <- lgb.call(
        fun_name = "LGBM_BoosterGetNumClasses_R"
        , ret = private$num_class
        , private$handle
      )

      self$params <- params

    },

    # Set the name reported for the training data in evaluation results
    set_train_data_name = function(name) {

      private$name_train_set <- name
      return(invisible(self))

    },

    # Register an lgb.Dataset as validation data under `name`.
    # It must share the booster's initial predictor.
    add_valid = function(data, name) {

      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      lgb.call(
        fun_name = "LGBM_BoosterAddValidData_R"
        , ret = NULL
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      # Validation set i is addressed as dataset index i + 1L
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Reset booster parameters. `params` and `...` are merged over the current
    # `self$params` (later values win) and the full set is sent to the booster.
    reset_parameter = function(params, ...) {

      if (methods::is(self$params, "list")) {
        params <- modifyList(self$params, params)
      }

      params <- modifyList(params, list(...))
      params_str <- lgb.params2str(params = params)

      lgb.call(
        fun_name = "LGBM_BoosterResetParameter_R"
        , ret = NULL
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))

    },

    # Perform one boosting iteration.
    # @param train_set optional new lgb.Dataset to train on. When NULL and the
    #        stored training set has changed version, the stored set is re-attached.
    # @param fobj optional custom objective: function(preds, dtrain) returning
    #        list(grad = ..., hess = ...). Using it switches the objective to "none".
    update = function(train_set = NULL, fobj = NULL) {

      # Re-attach the stored training set if it was modified since it was attached.
      # A booster loaded from a model file/string has no stored training set.
      if (is.null(train_set) && !is.null(private$train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # The predictor lives in the dataset's private environment (as in add_valid());
        # `train_set$predictor` is not a public field and would always be NULL.
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        lgb.call(
          fun_name = "LGBM_BoosterResetTrainingData_R"
          , ret = NULL
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # After a custom objective was used the booster objective is "none"
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        ret <- lgb.call(
          fun_name = "LGBM_BoosterUpdateOneIter_R"
          , ret = NULL
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        # Custom objective receives current training-set predictions
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        # is.list() first so a non-list result fails with this message, not a `$` error
        if (!is.list(gpair) || is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should
            return a list with attributes (hess, grad)")
        }

        ret <- lgb.call(
          fun_name = "LGBM_BoosterUpdateOneIterCustom_R"
          , ret = NULL
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # The model changed, so every cached prediction is stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(ret)

    },

    # Undo the most recent boosting iteration
    rollback_one_iter = function() {

      lgb.call(
        fun_name = "LGBM_BoosterRollbackOneIter_R"
        , ret = NULL
        , private$handle
      )

      # The model changed, so every cached prediction is stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Get current iteration
    current_iter = function() {

      cur_iter <- 0L
      lgb.call(
        fun_name = "LGBM_BoosterGetCurrentIteration_R"
        , ret = cur_iter
        , private$handle
      )

    },

    # Get upper bound
    upper_bound = function() {

      upper_bound <- 0.0
      lgb.call(
        fun_name = "LGBM_BoosterGetUpperBoundValue_R"
        , ret = upper_bound
        , private$handle
      )

    },

    # Get lower bound
    lower_bound = function() {

      lower_bound <- 0.0
      lgb.call(
        fun_name = "LGBM_BoosterGetLowerBoundValue_R"
        , ret = lower_bound
        , private$handle
      )

    },

    # Evaluate `data` on the booster's metrics (plus optional `feval`).
    # A dataset that is neither the training set nor a known validation set is
    # first registered as validation data under `name`.
    eval = function(data, name, feval = NULL) {

      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the dataset: 1L = training set, i + 1L = i-th validation set
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else if (length(private$valid_sets) > 0L) {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1L
            break
          }
        }
      }

      # Unknown dataset: add it; it becomes the last dataset index
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      private$inner_eval(
        data_name = name
        , data_idx = data_idx
        , feval = feval
      )

    },

    # Evaluate the training data
    eval_train = function(feval = NULL) {
      private$inner_eval(private$name_train_set, 1L, feval)
    },

    # Evaluate every registered validation set; returns a flat list of results
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },

    # Save model to `filename`; `num_iteration = NULL` uses `self$best_iter`
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      lgb.call(
        fun_name = "LGBM_BoosterSaveModel_R"
        , ret = NULL
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , lgb.c_str(x = filename)
      )

      return(invisible(self))
    },

    # Save model to a string; `num_iteration = NULL` uses `self$best_iter`
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      return(lgb.call.return.str(
        fun_name = "LGBM_BoosterSaveModelToString_R"
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      ))

    },

    # Dump model in memory; `num_iteration = NULL` uses `self$best_iter`
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      lgb.call.return.str(
        fun_name = "LGBM_BoosterDumpModel_R"
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
      )

    },

    # Predict on new data.
    # Defaults: `num_iteration = NULL` -> `self$best_iter`; `start_iteration = NULL` -> 0L.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE, ...) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # `...` is forwarded to the Predictor constructor
      predictor <- Predictor$new(private$handle, ...)
      predictor$predict(
        data = data
        , start_iteration = start_iteration
        , num_iteration = num_iteration
        , rawscore = rawscore
        , predleaf = predleaf
        , predcontrib = predcontrib
        , header = header
        , reshape = reshape
      )

    },

    # Transform into predictor
    to_predictor = function() {
      Predictor$new(private$handle)
    },

    # Serialized model string filled by save()
    raw = NA,

    # Store the model string in `self$raw` for in-memory saving
    save = function() {

      self$raw <- self$save_model_to_string(NULL)

    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    # Cached predictions, keyed by dataset name
    predict_buffer = list(),
    # One flag per dataset index: is predict_buffer current for this iteration?
    is_predicted_cur_iter = list(),
    num_class = 1L,
    # Training set (if any) plus validation sets
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    # TRUE once a custom objective switched the booster objective to "none"
    set_objective_to_none = FALSE,
    train_set_version = 0L,

    # Return (cached) predictions for dataset index `idx`
    # (1L = training set, i + 1L = i-th validation set)
    inner_predict = function(idx) {

      # Validate before indexing name_valid_sets, so an out-of-range index
      # gets this message rather than "subscript out of bounds"
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the prediction buffer on first use; the C API uses 0-based indices
      if (is.null(private$predict_buffer[[data_name]])) {

        npred <- 0L
        npred <- lgb.call(
          fun_name = "LGBM_BoosterGetNumPredict_R"
          , ret = npred
          , private$handle
          , as.integer(idx - 1L)
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer only if not yet done for the current iteration
      if (!private$is_predicted_cur_iter[[idx]]) {

        private$predict_buffer[[data_name]] <- lgb.call(
          fun_name = "LGBM_BoosterGetPredict_R"
          , ret = private$predict_buffer[[data_name]]
          , private$handle
          , as.integer(idx - 1L)
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Fetch and cache metric names and their "higher is better" flags
    get_eval_info = function() {

      if (is.null(private$eval_names)) {

        names <- lgb.call.return.str(
          fun_name = "LGBM_BoosterGetEvalNames_R"
          , private$handle
        )

        if (nchar(names) > 0L) {

          # Metric names come back tab-separated
          names <- strsplit(names, "\t")[[1L]]
          private$eval_names <- names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },

    # Evaluate dataset index `data_idx` on built-in metrics and optional `feval`.
    # Returns a list of list(data_name, name, value, higher_better).
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      private$get_eval_info()

      ret <- list()

      # Built-in metrics
      if (length(private$eval_names) > 0L) {

        tmp_vals <- numeric(length(private$eval_names))
        tmp_vals <- lgb.call(
          fun_name = "LGBM_BoosterGetEval_R"
          , ret = tmp_vals
          , private$handle
          , as.integer(data_idx - 1L)
        )

        for (i in seq_along(private$eval_names)) {

          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Custom metric
      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        # is.list() first so a non-list result fails with this message, not a `$` error
        if (!is.list(res) || is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a
            list with attribute (name, value, higher_better)")
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }

  )
)

715
716
717
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param start_iteration integer or \code{NULL}, optional (default = \code{NULL}).
#'                        Start index of the iteration to predict.
#'                        If \code{NULL} or <= 0, starts from the first iteration.
#' @param num_iteration integer or \code{NULL}, optional (default = \code{NULL}).
#'                      Limit number of iterations in the prediction.
#'                      If \code{NULL}: when a best iteration exists and \code{start_iteration} is
#'                      \code{NULL} or <= 0, the best iteration is used; otherwise, all iterations
#'                      from \code{start_iteration} are used.
#'                      If <= 0, all iterations from \code{start_iteration} are used (no limits).
#' @param rawscore whether the prediction should be returned in the form of original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf indices instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. \code{TRUE} if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
#'
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' preds <- predict(model, test$data)
#' }
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the Booster's own predict() method
  object$predict(
    data = data
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , reshape = reshape
    , ...
  )
}

796
797
798
799
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
#' @param filename path of model file (a single character string)
#' @param model_str a single character string containing the model
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # A file path takes precedence over a model string
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # Exactly one path: file.exists() on a longer vector returns several
    # values, which cannot be used as an if() condition
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single string")
    }
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.character(model_str)) {
      stop("lgb.load: model_str should be character")
    }
    if (length(model_str) != 1L) {
      stop("lgb.load: model_str should be a single string")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")
}

856
857
858
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
Guolin Ke's avatar
Guolin Ke committed
859
860
861
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to include in the saved model; NULL or <= 0 means use the best iteration
862
#'
863
#' @return lgb.Booster
864
#'
Guolin Ke's avatar
Guolin Ke committed
865
#' @examples
866
#' \donttest{
867
868
869
870
871
872
873
874
875
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
876
877
878
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
879
#'   , nrounds = 10L
880
#'   , valids = valids
881
882
883
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
884
#' )
885
#' lgb.save(model, tempfile(fileext = ".txt"))
886
#' }
Guolin Ke's avatar
Guolin Ke committed
887
#' @export
888
lgb.save <- function(booster, filename, num_iteration = NULL) {

  # Only an lgb.Booster knows how to write itself to disk
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # The destination must be exactly one file path
  is_single_path <- is.character(filename) && length(filename) == 1L
  if (!is_single_path) {
    stop("lgb.save: filename should be a string")
  }

  # Delegate to the booster's save method and return its result invisibly
  saved <- booster$save_model(
    filename = filename
    , num_iteration = num_iteration
  )
  invisible(saved)
}

906
907
908
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
Guolin Ke's avatar
Guolin Ke committed
909
910
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dumped model; NULL or <= 0 means use the best iteration
911
#'
Guolin Ke's avatar
Guolin Ke committed
912
#' @return json format of model
913
#'
Guolin Ke's avatar
Guolin Ke committed
914
#' @examples
915
#' \donttest{
916
917
918
919
920
921
922
923
924
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
925
926
927
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
928
#'   , nrounds = 10L
929
#'   , valids = valids
930
931
932
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
933
#' )
934
#' json_model <- lgb.dump(model)
935
#' }
Guolin Ke's avatar
Guolin Ke committed
936
#' @export
937
lgb.dump <- function(booster, num_iteration = NULL) {

  # Validate input; the message names this function (it previously said
  # "lgb.save", which pointed users at the wrong call)
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return the booster's dump of the model, limited to num_iteration if given
  booster$dump_model(num_iteration = num_iteration)
}

948
949
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
950
951
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
Guolin Ke's avatar
Guolin Ke committed
952
#' @param booster Object of class \code{lgb.Booster}
953
954
955
956
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
Guolin Ke's avatar
Guolin Ke committed
957
#' @param is_err if TRUE, return the recorded evaluation errors instead of the evaluation values
958
#'
959
#' @return numeric vector of evaluation results
960
#'
961
#' @examples
962
#' \donttest{
963
#' # train a regression model
964
965
966
967
968
969
970
971
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
972
973
974
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
975
#'   , nrounds = 5L
976
#'   , valids = valids
977
978
#'   , min_data = 1L
#'   , learning_rate = 1.0
979
#' )
980
981
982
983
984
985
986
987
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
988
#' lgb.get.eval.result(model, "test", "l2")
989
#' }
Guolin Ke's avatar
Guolin Ke committed
990
#' @export
991
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Check if booster is booster
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # data_name and eval_name drive the scalar if() conditions below, so each
  # must be a single string; a longer vector would make those conditions invalid
  is_single_string <- function(x) {
    is.character(x) && length(x) == 1L
  }
  if (!is_single_string(data_name) || !is_single_string(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check that the requested metric was recorded for this dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Select either the recorded values or their errors
  eval_record <- booster$record_evals[[data_name]][[eval_name]]
  if (is_err) {
    result <- eval_record[[.EVAL_ERR_KEY()]]
  } else {
    result <- eval_record[[.EVAL_KEY()]]
  }

  # With no iterations requested, return every recorded value
  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Translate iteration numbers into positions within result: the value for
  # iteration start_iter is stored at position 1
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1L
  positions <- iters - delta

  # Iterations before start_iter give zero or negative positions, which `[`
  # treats as "drop this element" and would silently return the wrong values.
  # Map them to NA instead, matching iterations past the last recorded one.
  positions[which(positions < 1L)] <- NA

  # Return requested result
  as.numeric(result[positions])
}