lgb.Booster.R 30.5 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
3
  classname = "lgb.Booster",
4
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
5
  public = list(
6

7
    best_iter = -1L,
8
    best_score = NA_real_,
9
    params = list(),
Guolin Ke's avatar
Guolin Ke committed
10
    record_evals = list(),
11

12
13
    # Release the native booster handle, if one is still held
    finalize = function() {

      # Nothing to release when the stored handle is already null
      if (lgb.is.null.handle(x = private$handle)) {
        return(invisible(NULL))
      }

      # Free the native booster, then clear the stored reference
      lgb.call(
        fun_name = "LGBM_BoosterFree_R"
        , ret = NULL
        , private$handle
      )
      private$handle <- NULL

      return(invisible(NULL))

    },
27

28
29
    # Create a booster from exactly one source, tried in this order:
    # a training lgb.Dataset, a model file path, or a model string
    #
    # params:    list of parameters; named arguments passed through `...`
    #            are appended to it
    # train_set: lgb.Dataset to train on
    # modelfile: path of a saved model file (character)
    # model_str: model serialized as a string (character)
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      params <- append(params, list(...))
      handle <- lgb.null.handle()

      # Errors raised inside try() are reported but do not abort here; the
      # null-handle check after this block is what stops construction
      try({

        if (!is.null(train_set)) {

          # Only an lgb.Dataset is accepted as training data
          if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }

          # R6 keeps private members in an environment, so this is a live
          # reference and not a snapshot
          dataset_private <- train_set$.__enclos_env__$private

          dataset_handle <- dataset_private$get_handle()
          params <- modifyList(params, train_set$get_params())
          encoded_params <- lgb.params2str(params = params)

          handle <- lgb.call(
            fun_name = "LGBM_BoosterCreate_R"
            , ret = handle
            , dataset_handle
            , encoded_params
          )

          # Remember the training data and the version it had at creation
          private$train_set <- train_set
          private$train_set_version <- dataset_private$version
          private$num_dataset <- 1L
          private$init_predictor <- dataset_private$predictor

          # Merge the dataset's initial predictor into the new booster, if any
          if (!is.null(private$init_predictor)) {
            lgb.call(
              fun_name = "LGBM_BoosterMerge_R"
              , ret = NULL
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )
          }

          # One "already predicted" flag per registered dataset
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          # The model file must be given as a path string
          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          handle <- lgb.call(
            fun_name = "LGBM_BoosterCreateFromModelfile_R"
            , ret = handle
            , lgb.c_str(x = modelfile)
          )

        } else if (!is.null(model_str)) {

          # The serialized model must be given as a string
          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          handle <- lgb.call(
            fun_name = "LGBM_BoosterLoadModelFromString_R"
            , ret = handle
            , lgb.c_str(x = model_str)
          )

        } else {

          # No source was supplied at all
          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      # A null handle here means one of the branches above failed
      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      # Tag and store the handle, then query the number of classes
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      private$num_class <- 1L
      private$num_class <- lgb.call(
        fun_name = "LGBM_BoosterGetNumClasses_R"
        , ret = private$num_class
        , private$handle
      )

      self$params <- params

      return(invisible(NULL))

    },
145

146
    # Change the name under which the training data is reported;
    # returns the booster invisibly so calls can be chained
    set_train_data_name = function(name) {

      private$name_train_set <- name

      return(invisible(self))

    },
154

155
    # Register a validation dataset with the booster
    #
    # data: lgb.Dataset sharing the booster's initial predictor
    # name: character label used when reporting results for this dataset
    #
    # Returns the booster invisibly
    add_valid = function(data, name) {

      # Only an lgb.Dataset is accepted as validation data
      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # R6 keeps private members in an environment, so this is a live reference
      data_private <- data$.__enclos_env__$private

      # The dataset must carry the same predictor the booster started from
      if (!identical(data_private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      # The label has to be a character value
      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Attach the dataset on the native side
      lgb.call(
        fun_name = "LGBM_BoosterAddValidData_R"
        , ret = NULL
        , private$handle
        , data_private$get_handle()
      )

      # Mirror the registration in the R-side bookkeeping
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },
193

194
    # Update the booster's parameters
    #
    # params: list of parameters, layered on top of the currently stored ones
    # ...:    extra named parameters, layered on top of `params`
    #
    # Returns the booster invisibly
    reset_parameter = function(params, ...) {

      # Start from the stored parameters when they are a list
      merged_params <- params
      if (methods::is(self$params, "list")) {
        merged_params <- modifyList(self$params, params)
      }

      # Arguments given through `...` win over everything else
      merged_params <- modifyList(merged_params, list(...))
      encoded_params <- lgb.params2str(params = merged_params)

      lgb.call(
        fun_name = "LGBM_BoosterResetParameter_R"
        , ret = NULL
        , private$handle
        , encoded_params
      )

      # Store the new parameters only after the native call has returned
      self$params <- merged_params

      return(invisible(self))

    },
215

216
    # Run one boosting iteration, optionally switching to a new training set
    # and/or using a custom objective function
    #
    # train_set: optional lgb.Dataset to train on from now on. When NULL, the
    #            stored training set is re-registered only if its version
    #            differs from the one recorded by the booster
    # fobj:      optional function(preds, train_set) returning a list with
    #            elements `grad` and `hess`; NULL uses the built-in objective
    #
    # Returns the value returned by the underlying update call
    update = function(train_set = NULL, fobj = NULL) {

      # Without an explicit train_set, pick up the stored one if it changed
      if (is.null(train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        # Only an lgb.Dataset is accepted as training data
        if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Read the predictor through the Dataset's private environment, the
        # same access path initialize() and add_valid() use, so the comparison
        # is made against the value init_predictor was taken from
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        # Point the native booster at the new training data
        lgb.call(
          fun_name = "LGBM_BoosterResetTrainingData_R"
          , ret = NULL
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        # Remember the dataset and the version it had when registered
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # The objective was switched to "none" by an earlier custom-objective
        # update, so there is nothing built in to boost with
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        # Boost one iteration with the built-in objective
        ret <- lgb.call(
          fun_name = "LGBM_BoosterUpdateOneIter_R"
          , ret = NULL
          , private$handle
        )

      } else {

        # A custom objective must be callable
        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # Switch the objective to "none" the first time a custom one is used
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        # Evaluate the custom objective on the current training predictions
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        # Both the gradient and the hessian are required
        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective should "
            , "return a list with attributes (hess, grad)"
          )
        }

        # Boost one iteration from the supplied gradient and hessian
        ret <- lgb.call(
          fun_name = "LGBM_BoosterUpdateOneIterCustom_R"
          , ret = NULL
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # Mark the cached predictions of every dataset as stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(ret)

    },
303

304
    # Undo the most recent boosting iteration; returns the booster invisibly
    rollback_one_iter = function() {

      lgb.call(
        fun_name = "LGBM_BoosterRollbackOneIter_R"
        , ret = NULL
        , private$handle
      )

      # Mark the cached predictions of every dataset as stale
      for (dataset_idx in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[dataset_idx]] <- FALSE
      }

      return(invisible(self))

    },
322

323
    # Query the booster's current iteration number
    current_iter = function() {

      # Integer placeholder handed to the native call as the return slot
      iteration_slot <- 0L
      iteration_count <- lgb.call(
        fun_name = "LGBM_BoosterGetCurrentIteration_R"
        , ret = iteration_slot
        , private$handle
      )

      return(iteration_count)

    },
336

337
    # Query the upper bound value reported by the native booster
    upper_bound = function() {

      # Double placeholder handed to the native call as the return slot
      bound_slot <- 0.0
      bound_value <- lgb.call(
        fun_name = "LGBM_BoosterGetUpperBoundValue_R"
        , ret = bound_slot
        , private$handle
      )

      return(bound_value)

    },

    # Query the lower bound value reported by the native booster
    lower_bound = function() {

      # Double placeholder handed to the native call as the return slot
      bound_slot <- 0.0
      bound_value <- lgb.call(
        fun_name = "LGBM_BoosterGetLowerBoundValue_R"
        , ret = bound_slot
        , private$handle
      )

      return(bound_value)

    },

365
    # Evaluate the booster on a dataset
    #
    # data:  lgb.Dataset to evaluate on; registered as a new validation set
    #        under `name` when the booster does not know it yet
    # name:  label attached to the returned results
    # feval: optional custom evaluation function
    #
    # Returns the list produced by the private inner_eval() method
    eval = function(data, name, feval = NULL) {

      # Only an lgb.Dataset can be evaluated
      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Position 1 is the training data, position k + 1 the k-th validation
      # set; 0 means the dataset has not been seen before
      dataset_position <- 0L
      if (identical(data, private$train_set)) {
        dataset_position <- 1L
      } else {
        for (k in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[k]])) {
            dataset_position <- k + 1L
            break
          }
        }
      }

      # Unknown dataset: register it, which makes it the last position
      if (dataset_position == 0L) {
        self$add_valid(data, name)
        dataset_position <- private$num_dataset
      }

      evaluation <- private$inner_eval(
        data_name = name
        , data_idx = dataset_position
        , feval = feval
      )

      return(evaluation)

    },
419

420
    # Evaluate the booster on its training data (dataset position 1)
    eval_train = function(feval = NULL) {

      training_results <- private$inner_eval(
        data_name = private$name_train_set
        , data_idx = 1L
        , feval = feval
      )

      return(training_results)

    },
424

425
    # Evaluate the booster on every registered validation set and return
    # all results in one flat list (empty when there are no validation sets)
    eval_valid = function(feval = NULL) {

      collected <- list()

      # Validation set k lives at dataset position k + 1
      for (k in seq_along(private$valid_sets)) {
        set_results <- private$inner_eval(
          data_name = private$name_valid_sets[[k]]
          , data_idx = k + 1L
          , feval = feval
        )
        collected <- append(x = collected, values = set_results)
      }

      return(collected)

    },
447

448
    # Write the model to a file
    #
    # filename:                path of the file to write
    # num_iteration:           iteration count passed to the native call;
    #                          NULL falls back to `self$best_iter`
    # feature_importance_type: passed to the native call as an integer
    #
    # Returns the booster invisibly
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      iteration_limit <- if (is.null(num_iteration)) self$best_iter else num_iteration

      lgb.call(
        fun_name = "LGBM_BoosterSaveModel_R"
        , ret = NULL
        , private$handle
        , as.integer(iteration_limit)
        , as.integer(feature_importance_type)
        , lgb.c_str(x = filename)
      )

      return(invisible(self))

    },
468

469
    # Serialize the model to a string
    #
    # num_iteration:           iteration count passed to the native call;
    #                          NULL falls back to `self$best_iter`
    # feature_importance_type: passed to the native call as an integer
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      iteration_limit <- if (is.null(num_iteration)) self$best_iter else num_iteration

      model_string <- lgb.call.return.str(
        fun_name = "LGBM_BoosterSaveModelToString_R"
        , private$handle
        , as.integer(iteration_limit)
        , as.integer(feature_importance_type)
      )

      return(model_string)

    },
488

489
    # Dump the model as a string via the native dump routine
    #
    # num_iteration:           iteration count passed to the native call;
    #                          NULL falls back to `self$best_iter`
    # feature_importance_type: passed to the native call as an integer
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      iteration_limit <- if (is.null(num_iteration)) self$best_iter else num_iteration

      model_dump <- lgb.call.return.str(
        fun_name = "LGBM_BoosterDumpModel_R"
        , private$handle
        , as.integer(iteration_limit)
        , as.integer(feature_importance_type)
      )

      return(model_dump)

    },
507

508
    # Predict on new data through a Predictor built from this booster
    #
    # start_iteration: NULL is replaced by 0L
    # num_iteration:   NULL is replaced by `self$best_iter`
    # ...:             forwarded to the Predictor constructor
    # All remaining arguments are handed to the Predictor's predict() unchanged
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE, ...) {

      # Fill in the defaults for the iteration window
      iteration_limit <- if (is.null(num_iteration)) self$best_iter else num_iteration
      first_iteration <- if (is.null(start_iteration)) 0L else start_iteration

      predictor <- Predictor$new(private$handle, ...)
      predictions <- predictor$predict(
        data = data
        , start_iteration = first_iteration
        , num_iteration = iteration_limit
        , rawscore = rawscore
        , predleaf = predleaf
        , predcontrib = predcontrib
        , header = header
        , reshape = reshape
      )

      return(predictions)

    },
543

544
545
    # Build a Predictor around this booster's handle
    to_predictor = function() {

      predictor <- Predictor$new(private$handle)

      return(predictor)

    },
548

549
    # Used for save
550
    raw = NA,
551

552
    # Serialize the model to a string and keep it in `self$raw`
    save = function() {

      model_string <- self$save_model_to_string(NULL)
      self$raw <- model_string

      return(invisible(NULL))

    }
561

Guolin Ke's avatar
Guolin Ke committed
562
563
  ),
  private = list(
564
565
566
567
568
569
570
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
571
572
    num_class = 1L,
    num_dataset = 0L,
573
574
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
575
    higher_better_inner_eval = NULL,
576
    set_objective_to_none = FALSE,
577
    train_set_version = 0L,
578
579
    # Return the cached predictions for a registered dataset, refreshing the
    # cache when the per-dataset flag says it is stale
    #
    # idx: 1 for the training data, k + 1 for the k-th validation set
    inner_predict = function(idx) {

      # Validate the index before it is used for any lookup, so an
      # out-of-range value fails with this message and not a subscript error
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Buffers are keyed by dataset name
      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the buffer the first time this dataset is predicted
      if (is.null(private$predict_buffer[[data_name]])) {

        npred <- 0L
        npred <- lgb.call(
          fun_name = "LGBM_BoosterGetNumPredict_R"
          , ret = npred
          , private$handle
          , as.integer(idx - 1L)
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refill the buffer only when it is stale
      if (!private$is_predicted_cur_iter[[idx]]) {

        # The native side uses zero-based dataset indices
        private$predict_buffer[[data_name]] <- lgb.call(
          fun_name = "LGBM_BoosterGetPredict_R"
          , ret = private$predict_buffer[[data_name]]
          , private$handle
          , as.integer(idx - 1L)
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE

      }

      return(private$predict_buffer[[data_name]])

    },
624

625
    # Return the names of the booster's evaluation metrics, fetching and
    # caching them (plus their higher-is-better flags) on first use
    get_eval_info = function() {

      # Already cached
      if (!is.null(private$eval_names)) {
        return(private$eval_names)
      }

      # The native call returns all names in one tab-separated string
      joined_names <- lgb.call.return.str(
        fun_name = "LGBM_BoosterGetEvalNames_R"
        , private$handle
      )

      if (nchar(joined_names) > 0L) {

        metric_labels <- strsplit(joined_names, "\t")[[1L]]
        private$eval_names <- metric_labels

        # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
        # ndcg metric evaluated at the first "query result" in learning-to-rank
        base_metrics <- gsub("@.*", "", metric_labels)
        private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[base_metrics]

      }

      return(private$eval_names)

    },
656

657
    # Evaluate the booster on one registered dataset
    #
    # data_name: label copied into every returned result
    # data_idx:  1 for the training data, k + 1 for the k-th validation set
    # feval:     optional function(preds, dataset) returning a list with
    #            elements `name`, `value` and `higher_better`
    #
    # Returns a list with one entry per built-in metric, followed by the
    # custom metric when `feval` is given. Each entry is a list with elements
    # data_name, name, value and higher_better
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Reject indices beyond the registered datasets
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Make sure the metric names are loaded
      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # One value per metric; the native side uses zero-based dataset indices
        tmp_vals <- numeric(length(private$eval_names))
        tmp_vals <- lgb.call(
          fun_name = "LGBM_BoosterGetEval_R"
          , ret = tmp_vals
          , private$handle
          , as.integer(data_idx - 1L)
        )

        # Package each metric as its own result entry
        for (i in seq_along(private$eval_names)) {

          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      if (!is.null(feval)) {

        # A custom metric must be callable
        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset object that matches data_idx
        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        # All three elements are required in the custom result
        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a "
            , "list with attribute (name, value, higher_better)"
          )
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))

      }

      return(ret)

    }
731

Guolin Ke's avatar
Guolin Ke committed
732
733
734
  )
)

735
736
737
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
Guolin Ke's avatar
Guolin Ke committed
738
739
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
740
741
742
743
744
745
746
747
#' @param start_iteration int or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration int or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
748
#' @param rawscore whether the prediction should be returned in the form of original untransformed
749
750
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
751
#' @param predleaf whether predict leaf index instead.
752
#' @param predcontrib return per-feature contributions for each record.
Guolin Ke's avatar
Guolin Ke committed
753
#' @param header only used for prediction for text file. True if text file has header
754
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
755
#'                prediction outputs per case.
James Lamb's avatar
James Lamb committed
756
757
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
758
759
760
761
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
762
#'
763
764
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
765
#'
Guolin Ke's avatar
Guolin Ke committed
766
#' @examples
767
#' \donttest{
768
769
770
771
772
773
774
775
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
776
777
778
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
779
#'   , nrounds = 5L
780
#'   , valids = valids
781
782
#'   , min_data = 1L
#'   , learning_rate = 1.0
783
#' )
784
#' preds <- predict(model, test$data)
785
#' }
Guolin Ke's avatar
Guolin Ke committed
786
#' @export
James Lamb's avatar
James Lamb committed
787
788
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  # This S3 method only makes sense for an lgb.Booster
  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the booster's own predict() method, forwarding everything
  predictions <- object$predict(
    data = data
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , reshape = reshape
    , ...
  )

  return(predictions)

}

818
819
820
821
#' @name lgb.load
#' @title Load LightGBM model
#' @description  Load LightGBM takes in either a file path or model string.
#'               If both are provided, Load will default to loading from file
Guolin Ke's avatar
Guolin Ke committed
822
#' @param filename path of model file
823
#' @param model_str a str containing the model
824
#'
825
#' @return lgb.Booster
826
#'
Guolin Ke's avatar
Guolin Ke committed
827
#' @examples
828
#' \donttest{
829
830
831
832
833
834
835
836
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
837
838
839
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
840
#'   , nrounds = 5L
841
#'   , valids = valids
842
843
#'   , min_data = 1L
#'   , learning_rate = 1.0
844
#'   , early_stopping_rounds = 3L
845
#' )
846
847
848
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
849
850
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
851
#' }
Guolin Ke's avatar
Guolin Ke committed
852
#' @export
853
lgb.load <- function(filename = NULL, model_str = NULL) {

  # A file path takes precedence when both sources are supplied
  if (!is.null(filename)) {

    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }

    booster_from_file <- Booster$new(modelfile = filename)
    return(invisible(booster_from_file))

  }

  # Otherwise fall back to the serialized model string
  if (!is.null(model_str)) {

    if (!is.character(model_str)) {
      stop("lgb.load: model_str should be character")
    }

    booster_from_string <- Booster$new(model_str = model_str)
    return(invisible(booster_from_string))

  }

  # Neither source was provided
  stop("lgb.load: either filename or model_str must be given")

}

878
879
880
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
Guolin Ke's avatar
Guolin Ke committed
881
882
883
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to use; NULL or <= 0 means use the best iteration
884
#'
885
#' @return lgb.Booster
886
#'
Guolin Ke's avatar
Guolin Ke committed
887
#' @examples
888
#' \donttest{
889
890
891
892
893
894
895
896
897
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
898
899
900
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
901
#'   , nrounds = 10L
902
#'   , valids = valids
903
904
905
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
906
#' )
907
#' lgb.save(model, tempfile(fileext = ".txt"))
908
#' }
Guolin Ke's avatar
Guolin Ke committed
909
#' @export
910
lgb.save <- function(booster, filename, num_iteration = NULL) {

  # Validates the arguments, then delegates to booster$save_model().
  # Whatever save_model() returns is handed back invisibly.

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # filename has to be exactly one character value
  filename_is_string <- is.character(filename) && length(filename) == 1L
  if (!filename_is_string) {
    stop("lgb.save: filename should be a string")
  }

  saved <- booster$save_model(
    filename = filename
    , num_iteration = num_iteration
  )

  return(invisible(saved))

}

930
931
932
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
Guolin Ke's avatar
Guolin Ke committed
933
934
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to use; NULL or <= 0 means use the best iteration
935
#'
Guolin Ke's avatar
Guolin Ke committed
936
#' @return json format of model
937
#'
Guolin Ke's avatar
Guolin Ke committed
938
#' @examples
939
#' \donttest{
940
941
942
943
944
945
946
947
948
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
949
950
951
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
952
#'   , nrounds = 10L
953
#'   , valids = valids
954
955
956
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
957
#' )
958
#' json_model <- lgb.dump(model)
959
#' }
Guolin Ke's avatar
Guolin Ke committed
960
#' @export
961
lgb.dump <- function(booster, num_iteration = NULL) {

  # Validates that `booster` is an lgb.Booster, then returns the result of
  # booster$dump_model() for the requested number of iterations.

  if (!lgb.is.Booster(x = booster)) {
    # The message names this function (it previously said "lgb.save")
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))

}

972
973
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
974
975
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
Guolin Ke's avatar
Guolin Ke committed
976
#' @param booster Object of class \code{lgb.Booster}
977
978
979
980
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
Guolin Ke's avatar
Guolin Ke committed
981
#' @param is_err TRUE will return evaluation error instead
982
#'
983
#' @return numeric vector of evaluation result
984
#'
985
#' @examples
986
#' \donttest{
987
#' # train a regression model
988
989
990
991
992
993
994
995
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
996
997
998
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
999
#'   , nrounds = 5L
1000
#'   , valids = valids
1001
1002
#'   , min_data = 1L
#'   , learning_rate = 1.0
1003
#' )
1004
1005
1006
1007
1008
1009
1010
1011
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
1012
#' lgb.get.eval.result(model, "test", "l2")
1013
#' }
Guolin Ke's avatar
Guolin Ke committed
1014
#' @export
1015
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Looks up recorded evaluation values in booster$record_evals for one
  # dataset (`data_name`) and one metric (`eval_name`).
  # Returns a numeric vector: all recorded values when `iters` is NULL,
  # otherwise only the values for the requested iterations.
  # When `is_err` is TRUE the entry stored under .EVAL_ERR_KEY() is returned
  # instead of the one stored under .EVAL_KEY().

  # Check if booster is booster
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Check if data and evaluation name are characters or not
  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check that the requested metric was recorded for this dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Choose between the evaluation value and its error, then look it up once
  if (is_err) {
    result_key <- .EVAL_ERR_KEY()
  } else {
    result_key <- .EVAL_KEY()
  }
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  # Without specific iterations, return everything that was recorded
  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Recorded values start at record_evals$start_iter, so shift the requested
  # iteration numbers to positions within `result`
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  # Return requested result
  return(as.numeric(result[iters]))
}