lgb.Booster.R 30 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
3
  classname = "lgb.Booster",
4
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
5
  public = list(
6

7
    best_iter = -1L,
8
    best_score = NA_real_,
Guolin Ke's avatar
Guolin Ke committed
9
    record_evals = list(),
10

11
12
    # Finalize will free up the handles
    finalize = function() {

      # Nothing to release when the stored handle is already a null handle
      if (lgb.is.null.handle(x = private$handle)) {
        return(invisible(NULL))
      }

      # Free the native Booster, then drop the stored reference to it
      lgb.call(fun_name = "LGBM_BoosterFree_R", ret = NULL, private$handle)
      private$handle <- NULL

    },
24

25
26
    # Initialize will create a starter booster
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      # Builds the Booster from the first source that is supplied, checked in
      # this order: train_set (an lgb.Dataset), modelfile (a path), model_str
      # (a model string). Anything passed through `...` is appended to `params`.
      params <- append(params, list(...))
      handle <- lgb.null.handle()

      # try() reports an error raised in this block without propagating it;
      # `handle` is inspected afterwards to detect a failed creation
      try({

        if (!is.null(train_set)) {

          # Source 1: a training dataset
          if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }

          dataset_private <- train_set$.__enclos_env__$private
          dataset_handle <- dataset_private$get_handle()

          # Parameters stored on the dataset override the ones passed in
          params <- modifyList(params, train_set$get_params())
          param_string <- lgb.params2str(params = params)

          handle <- lgb.call(
            fun_name = "LGBM_BoosterCreate_R"
            , ret = handle
            , dataset_handle
            , param_string
          )

          # Remember the dataset and the version it had when it was attached
          private$train_set <- train_set
          private$train_set_version <- dataset_private$version
          private$num_dataset <- 1L
          private$init_predictor <- dataset_private$predictor

          # A dataset that carries a predictor gets that predictor merged in
          if (!is.null(private$init_predictor)) {
            lgb.call(
              fun_name = "LGBM_BoosterMerge_R"
              , ret = NULL
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )
          }

          # One "already predicted" flag is kept per attached dataset
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          # Source 2: a model file on disk
          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          handle <- lgb.call(
            fun_name = "LGBM_BoosterCreateFromModelfile_R"
            , ret = handle
            , lgb.c_str(x = modelfile)
          )

        } else if (!is.null(model_str)) {

          # Source 3: a model held in a string
          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          handle <- lgb.call(
            fun_name = "LGBM_BoosterLoadModelFromString_R"
            , ret = handle
            , lgb.c_str(x = model_str)
          )

        } else {

          # No source at all
          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      # A handle that is still null means none of the branches above succeeded
      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      # Tag and store the handle, then read the class count from the model
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      private$num_class <- 1L
      private$num_class <- lgb.call(
        fun_name = "LGBM_BoosterGetNumClasses_R"
        , ret = private$num_class
        , private$handle
      )

    },
138

139
    # Set training data name
Guolin Ke's avatar
Guolin Ke committed
140
    set_train_data_name = function(name) {

      # Store the label used for the training data in evaluation results.
      # Returns the Booster invisibly so calls can be chained.
      private$name_train_set <- name
      invisible(self)

    },
147

148
    # Add validation data
Guolin Ke's avatar
Guolin Ke committed
149
    add_valid = function(data, name) {

      # Attaches `data` (an lgb.Dataset) as a validation set labelled `name`.
      # The dataset must carry the same predictor as the one recorded in
      # private$init_predictor. Returns the Booster invisibly.

      is_dataset <- lgb.check.r6.class(object = data, name = "lgb.Dataset")
      if (!is_dataset) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      same_predictor <- identical(
        data$.__enclos_env__$private$predictor
        , private$init_predictor
      )
      if (!same_predictor) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      # Register the dataset with the native Booster
      lgb.call(
        fun_name = "LGBM_BoosterAddValidData_R"
        , ret = NULL
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      # Mirror the registration in the R-side bookkeeping: the dataset, its
      # label, the dataset count and a fresh "already predicted" flag
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      invisible(self)

    },
186

187
    # Reset parameters of booster
Guolin Ke's avatar
Guolin Ke committed
188
    reset_parameter = function(params, ...) {

      # Sends a new set of parameters to the native Booster. Arguments given
      # through `...` are appended to `params` before they are converted to
      # the parameter string. Returns the Booster invisibly.
      merged_params <- append(params, list(...))
      param_string <- lgb.params2str(params = merged_params)

      lgb.call(
        fun_name = "LGBM_BoosterResetParameter_R"
        , ret = NULL
        , private$handle
        , param_string
      )

      invisible(self)

    },
205

206
    # Perform boosting update iteration
Guolin Ke's avatar
Guolin Ke committed
207
    update = function(train_set = NULL, fobj = NULL) {

      # Runs one boosting iteration.
      #   train_set: optional lgb.Dataset to train on from now on; when NULL the
      #              stored training data is used
      #   fobj:      optional custom objective, called as
      #              fobj(predictions, training_dataset); it must return a list
      #              with elements `grad` and `hess`
      # Returns the value returned by the underlying update call.

      if (is.null(train_set)) {

        # Without any training data the version comparison below would fail
        # with an opaque "argument is of length zero" error; fail clearly
        if (is.null(private$train_set)) {
          stop("lgb.Booster.update: cannot update a Booster that has no training data")
        }

        # Re-attach the stored dataset when its version no longer matches the
        # version recorded when it was attached
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }

      }

      # Attach a new (or changed) training dataset
      if (!is.null(train_set)) {

        if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Read the predictor through the Dataset's private environment, the
        # same accessor the rest of this class uses, so the comparison is made
        # against the dataset's actual predictor
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        lgb.call(
          fun_name = "LGBM_BoosterResetTrainingData_R"
          , ret = NULL
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # The objective was switched to "none" by an earlier custom-objective
        # update, so there is no built-in objective left to boost with
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        # Boost one iteration with the built-in objective
        ret <- lgb.call(
          fun_name = "LGBM_BoosterUpdateOneIter_R"
          , ret = NULL
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # Switch the objective to "none" once, the first time a custom
        # objective is used
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        # Let the custom objective compute gradients and hessians from the
        # current predictions on the training data
        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should return a list with attributes (hess, grad)")
        }

        # Boost one iteration from the supplied gradients and hessians
        ret <- lgb.call(
          fun_name = "LGBM_BoosterUpdateOneIterCustom_R"
          , ret = NULL
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # The model changed, so every cached prediction is now stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(ret)

    },
293

294
    # Return one iteration behind
Guolin Ke's avatar
Guolin Ke committed
295
    rollback_one_iter = function() {

      # Asks the native Booster to roll back one iteration, then marks every
      # cached prediction as stale. Returns the Booster invisibly.
      lgb.call(
        fun_name = "LGBM_BoosterRollbackOneIter_R"
        , ret = NULL
        , private$handle
      )

      for (flag_pos in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[flag_pos]] <- FALSE
      }

      invisible(self)

    },
312

313
    # Get current iteration
Guolin Ke's avatar
Guolin Ke committed
314
    current_iter = function() {

      # Returns the result of the LGBM_BoosterGetCurrentIteration_R call;
      # an integer zero is handed over as the return slot
      iteration <- 0L
      lgb.call(
        fun_name = "LGBM_BoosterGetCurrentIteration_R"
        , ret = iteration
        , private$handle
      )

    },
324

325
    # Get upper bound
326
    upper_bound = function() {

      # Returns the result of the LGBM_BoosterGetUpperBoundValue_R call;
      # a double zero is handed over as the return slot
      bound <- 0.0
      lgb.call(
        fun_name = "LGBM_BoosterGetUpperBoundValue_R"
        , ret = bound
        , private$handle
      )

    },

    # Get lower bound
338
    lower_bound = function() {

      # Returns the result of the LGBM_BoosterGetLowerBoundValue_R call;
      # a double zero is handed over as the return slot
      bound <- 0.0
      lgb.call(
        fun_name = "LGBM_BoosterGetLowerBoundValue_R"
        , ret = bound
        , private$handle
      )

    },

349
    # Evaluate data on metrics
Guolin Ke's avatar
Guolin Ke committed
350
    eval = function(data, name, feval = NULL) {

      # Evaluates `data` (an lgb.Dataset) and labels the results with `name`.
      # A dataset that is neither the training data nor an already attached
      # validation set is attached first. `feval` is an optional custom
      # evaluation function forwarded to the inner evaluation.

      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the dataset: 1 is the training data, k + 1 the k-th validation
      # set, 0 means "not attached yet"
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (pos in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[pos]])) {
            data_idx <- pos + 1L
            break
          }
        }
      }

      # Unknown dataset: attach it; it then occupies the last position
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      private$inner_eval(
        data_name = name
        , data_idx = data_idx
        , feval = feval
      )

    },
401

402
    # Evaluation training data
Guolin Ke's avatar
Guolin Ke committed
403
    eval_train = function(feval = NULL) {
      # Evaluate the training data, which always sits at dataset position 1
      private$inner_eval(
        data_name = private$name_train_set
        , data_idx = 1L
        , feval = feval
      )
    },
406

407
    # Evaluation validation data
Guolin Ke's avatar
Guolin Ke committed
408
    eval_valid = function(feval = NULL) {

      # Evaluates every attached validation set and returns all results in a
      # single list; the list is empty when no validation set is attached
      results <- list()

      for (pos in seq_along(private$valid_sets)) {
        # Validation set `pos` sits at dataset position pos + 1
        set_results <- private$inner_eval(private$name_valid_sets[[pos]], pos + 1L, feval)
        results <- append(x = results, values = set_results)
      }

      results

    },
429

430
    # Save model
431
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      # Writes the model to `filename`. When num_iteration is NULL the value
      # of self$best_iter is used instead. Returns the Booster invisibly.
      iterations <- if (is.null(num_iteration)) self$best_iter else num_iteration

      lgb.call(
        fun_name = "LGBM_BoosterSaveModel_R"
        , ret = NULL
        , private$handle
        , as.integer(iterations)
        , as.integer(feature_importance_type)
        , lgb.c_str(x = filename)
      )

      invisible(self)
    },
450

451
    # Save model to string
452
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      # Returns the model as a string. When num_iteration is NULL the value of
      # self$best_iter is used instead.
      iterations <- if (is.null(num_iteration)) self$best_iter else num_iteration

      model_string <- lgb.call.return.str(
        fun_name = "LGBM_BoosterSaveModelToString_R"
        , private$handle
        , as.integer(iterations)
        , as.integer(feature_importance_type)
      )

      model_string

    },
468

469
    # Dump model in memory
470
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      # Returns the string produced by the LGBM_BoosterDumpModel_R call. When
      # num_iteration is NULL the value of self$best_iter is used instead.
      iterations <- if (is.null(num_iteration)) self$best_iter else num_iteration

      lgb.call.return.str(
        fun_name = "LGBM_BoosterDumpModel_R"
        , private$handle
        , as.integer(iterations)
        , as.integer(feature_importance_type)
      )

    },
485

486
    # Predict on new data
Guolin Ke's avatar
Guolin Ke committed
487
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE, ...) {

      # Predicts on `data` through a Predictor built around this Booster's
      # handle; arguments in `...` go to the Predictor constructor.
      # Defaults: num_iteration falls back to self$best_iter and
      # start_iteration falls back to 0L.
      iteration_limit <- if (is.null(num_iteration)) self$best_iter else num_iteration
      first_iteration <- if (is.null(start_iteration)) 0L else start_iteration

      Predictor$new(private$handle, ...)$predict(
        data = data
        , start_iteration = first_iteration
        , num_iteration = iteration_limit
        , rawscore = rawscore
        , predleaf = predleaf
        , predcontrib = predcontrib
        , header = header
        , reshape = reshape
      )

    },
519

520
521
522
    # Transform into predictor
    to_predictor = function() {
      # Wrap this Booster's native handle in a new Predictor object
      predictor <- Predictor$new(private$handle)
      return(predictor)
    },
524

525
    # Used for save
526
    raw = NA,
527

528
    # Save model to temporary file for in-memory saving
529
    save = function() {

      # Keep a string copy of the model in the public `raw` field. Passing
      # NULL makes save_model_to_string() fall back to self$best_iter.
      self$raw <- self$save_model_to_string(num_iteration = NULL)

    }
535

Guolin Ke's avatar
Guolin Ke committed
536
537
  ),
  private = list(
538
539
540
541
542
543
544
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
545
546
    num_class = 1L,
    num_dataset = 0L,
547
548
    init_predictor = NULL,
    eval_names = NULL,
Guolin Ke's avatar
Guolin Ke committed
549
    higher_better_inner_eval = NULL,
550
    set_objective_to_none = FALSE,
551
    train_set_version = 0L,
552
553
    # Predict data
    inner_predict = function(idx) {

      # Returns the predictions for dataset number `idx` (1 = training data,
      # k + 1 = k-th validation set). Predictions are kept in a per-dataset
      # buffer and only fetched again when the flag for that dataset says
      # they have not been computed for the current iteration.

      # Validate the index before it is used for any lookup, so an
      # out-of-range value produces this message rather than a
      # "subscript out of bounds" error from the name lookup below
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # The buffer is keyed by dataset name
      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # First request for this dataset: allocate a buffer of the right size
      if (is.null(private$predict_buffer[[data_name]])) {

        npred <- 0L
        npred <- lgb.call(
          fun_name = "LGBM_BoosterGetNumPredict_R"
          , ret = npred
          , private$handle
          , as.integer(idx - 1L)
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer when it is stale; the native side counts datasets
      # from zero, hence idx - 1L
      if (!private$is_predicted_cur_iter[[idx]]) {

        private$predict_buffer[[data_name]] <- lgb.call(
          fun_name = "LGBM_BoosterGetPredict_R"
          , ret = private$predict_buffer[[data_name]]
          , private$handle
          , as.integer(idx - 1L)
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE

      }

      return(private$predict_buffer[[data_name]])
    },
598

599
    # Get evaluation information
Guolin Ke's avatar
Guolin Ke committed
600
    get_eval_info = function() {

      # Returns the evaluation metric names, fetching them from the native
      # Booster on first use and reusing the stored names afterwards
      if (!is.null(private$eval_names)) {
        return(private$eval_names)
      }

      joined_names <- lgb.call.return.str(
        fun_name = "LGBM_BoosterGetEvalNames_R"
        , private$handle
      )

      if (nchar(joined_names) > 0L) {

        # The names arrive as one tab-separated string
        eval_names <- strsplit(joined_names, "\t")[[1L]]
        private$eval_names <- eval_names

        # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
        # ndcg metric evaluated at the first "query result" in learning-to-rank
        base_metrics <- gsub("@.*", "", eval_names)
        private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[base_metrics]

      }

      return(private$eval_names)

    },
630

631
    # Perform inner evaluation
Guolin Ke's avatar
Guolin Ke committed
632
    inner_eval = function(data_name, data_idx, feval = NULL) {

      # Evaluates dataset number `data_idx` (1 = training data, k + 1 = k-th
      # validation set) and labels every result with `data_name`.
      # Returns a list with one entry per built-in metric, each a list with
      # data_name, name, value and higher_better, followed by the result of
      # `feval` when a custom evaluation function is supplied.

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      # Make sure the metric names have been loaded
      private$get_eval_info()

      ret <- list()

      # Built-in metrics
      if (length(private$eval_names) > 0L) {

        # One value per metric; the native side counts datasets from zero
        tmp_vals <- numeric(length(private$eval_names))
        tmp_vals <- lgb.call(
          fun_name = "LGBM_BoosterGetEval_R"
          , ret = tmp_vals
          , private$handle
          , as.integer(data_idx - 1L)
        )

        for (i in seq_along(private$eval_names)) {

          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      # Custom metric
      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset object that is handed to the custom function
        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        # The message is a single-line literal: the previous literal spanned
        # two source lines and carried the line break and indentation into
        # the error text
        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a list with attribute (name, value, higher_better)")
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }
705

Guolin Ke's avatar
Guolin Ke committed
706
707
708
  )
)

709
710
711
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
Guolin Ke's avatar
Guolin Ke committed
712
713
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
714
715
716
717
718
719
720
721
#' @param start_iteration integer or NULL, optional (default = NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration integer or NULL, optional (default = NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
722
#' @param rawscore whether the prediction should be returned in the form of the original untransformed
723
724
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
725
#' @param predleaf whether to predict leaf indices instead.
726
#' @param predcontrib return per-feature contributions for each record.
Guolin Ke's avatar
Guolin Ke committed
727
#' @param header only used for prediction from a text file. TRUE if the text file has a header
728
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
729
#'                prediction outputs per case.
James Lamb's avatar
James Lamb committed
730
731
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
732
733
734
735
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
736
#'
737
738
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
739
#'
Guolin Ke's avatar
Guolin Ke committed
740
#' @examples
741
#' \donttest{
742
743
744
745
746
747
748
749
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
750
751
752
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
753
#'   , nrounds = 5L
754
#'   , valids = valids
755
756
#'   , min_data = 1L
#'   , learning_rate = 1.0
757
#' )
758
#' preds <- predict(model, test$data)
759
#' }
Guolin Ke's avatar
Guolin Ke committed
760
#' @export
James Lamb's avatar
James Lamb committed
761
762
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  # Only lgb.Booster objects are accepted
  is_booster <- lgb.is.Booster(x = object)
  if (!is_booster) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the Booster's own predict() method, forwarding every option
  # and any additional arguments unchanged
  return(
    object$predict(
      data = data
      , start_iteration = start_iteration
      , num_iteration = num_iteration
      , rawscore = rawscore
      , predleaf = predleaf
      , predcontrib = predcontrib
      , header = header
      , reshape = reshape
      , ...
    )
  )

}

790
791
792
793
#' @name lgb.load
#' @title Load LightGBM model
#' @description  Load LightGBM takes in either a file path or model string.
#'               If both are provided, Load will default to loading from file
Guolin Ke's avatar
Guolin Ke committed
794
#' @param filename path of model file
795
#' @param model_str a str containing the model
796
#'
797
#' @return lgb.Booster
798
#'
Guolin Ke's avatar
Guolin Ke committed
799
#' @examples
800
#' \donttest{
801
802
803
804
805
806
807
808
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
809
810
811
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
812
#'   , nrounds = 5L
813
#'   , valids = valids
814
815
#'   , min_data = 1L
#'   , learning_rate = 1.0
816
#'   , early_stopping_rounds = 3L
817
#' )
818
819
820
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
821
822
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
823
#' }
Guolin Ke's avatar
Guolin Ke committed
824
#' @export
825
lgb.load <- function(filename = NULL, model_str = NULL) {

  # Creates a Booster from a model file or from a model string and returns it
  # invisibly. `filename` takes precedence when both arguments are supplied;
  # an error is raised when neither is given.

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # file.exists() is vectorised, so anything other than exactly one path
    # would put a non-scalar condition into the if() below
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single string")
    }
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.character(model_str)) {
      stop("lgb.load: model_str should be character")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")

}

850
851
852
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save LightGBM model
Guolin Ke's avatar
Guolin Ke committed
853
854
855
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iterations to save; NULL or <= 0 means use the best iteration
856
#'
857
#' @return lgb.Booster
858
#'
Guolin Ke's avatar
Guolin Ke committed
859
#' @examples
860
#' \donttest{
861
862
863
864
865
866
867
868
869
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
870
871
872
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
873
#'   , nrounds = 10L
874
#'   , valids = valids
875
876
877
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
878
#' )
879
#' lgb.save(model, tempfile(fileext = ".txt"))
880
#' }
Guolin Ke's avatar
Guolin Ke committed
881
#' @export
882
lgb.save <- function(booster, filename, num_iteration = NULL) {

  # Only an lgb.Booster can be written out
  is_booster <- lgb.is.Booster(x = booster)
  if (!is_booster) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # The target path has to be exactly one character value
  if (!is.character(filename) || length(filename) != 1L) {
    stop("lgb.save: filename should be a string")
  }

  # Hand the work to the booster's save_model() and return its result
  # without printing it
  saved <- booster$save_model(
    filename = filename
    , num_iteration = num_iteration
  )
  invisible(saved)

}

900
901
902
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
Guolin Ke's avatar
Guolin Ke committed
903
904
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump; NULL or <= 0 means use the best iteration
905
#'
Guolin Ke's avatar
Guolin Ke committed
906
#' @return json format of model
907
#'
Guolin Ke's avatar
Guolin Ke committed
908
#' @examples
909
#' \donttest{
910
911
912
913
914
915
916
917
918
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
919
920
921
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
922
#'   , nrounds = 10L
923
#'   , valids = valids
924
925
926
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
927
#' )
928
#' json_model <- lgb.dump(model)
929
#' }
Guolin Ke's avatar
Guolin Ke committed
930
#' @export
931
lgb.dump <- function(booster, num_iteration = NULL) {

  # Only an lgb.Booster can be dumped; the message names this function so the
  # failure is not mistaken for one raised by lgb.save
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the booster's dump_model() for the requested iteration
  booster$dump_model(num_iteration = num_iteration)

}

942
943
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
944
945
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
Guolin Ke's avatar
Guolin Ke committed
946
#' @param booster Object of class \code{lgb.Booster}
947
948
949
950
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
Guolin Ke's avatar
Guolin Ke committed
951
#' @param is_err If TRUE, return the evaluation error instead of the evaluation result
952
#'
953
#' @return numeric vector of evaluation result
954
#'
955
#' @examples
956
#' \donttest{
957
#' # train a regression model
958
959
960
961
962
963
964
965
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
966
967
968
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
969
#'   , nrounds = 5L
970
#'   , valids = valids
971
972
#'   , min_data = 1L
#'   , learning_rate = 1.0
973
#' )
974
975
976
977
978
979
980
981
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
982
#' lgb.get.eval.result(model, "test", "l2")
983
#' }
Guolin Ke's avatar
Guolin Ke committed
984
#' @export
985
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Purpose: pull the recorded values of one metric on one dataset out of
  # booster$record_evals, optionally restricted to specific iterations.

  # Check if booster is booster
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Both names must be single strings: they are used in scalar if() tests
  # and as [[ ]] keys below
  names_are_strings <- is.character(data_name) && length(data_name) == 1L &&
    is.character(eval_name) && length(eval_name) == 1L
  if (!names_are_strings) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check that the requested metric was recorded for this dataset
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick the stored series: the error values when is_err is TRUE,
  # otherwise the metric values
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  # No iterations requested: return every recorded value
  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Shift the requested iteration numbers by the recorded start_iter so they
  # become positions within the stored series
  iters <- as.integer(iters)
  delta <- booster$record_evals[["start_iter"]] - 1.0
  iters <- iters - delta

  # Return requested result
  as.numeric(result[iters])

}