lgb.Booster.R 32.4 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Best iteration found so far (-1L = not set). Methods below that take
    # `num_iteration = NULL` fall back to this value.
    best_iter = -1L,
    best_score = NA_real_,
    params = list(),
    record_evals = list(),

    # Finalizer (R6 calls it when the object is garbage-collected): frees the
    # native booster handle if one exists, then drops the reference to it.
    finalize = function() {

      if (!lgb.is.null.handle(x = private$handle)) {

        call_state <- 0L
        .Call(
          LGBM_BoosterFree_R
          , private$handle
          , call_state
        )
        private$handle <- NULL

      }

      return(invisible(NULL))

    },

    # Create a booster from exactly one source, tried in this order:
    #   1. `train_set`: an lgb.Dataset (its parameters are merged into `params`)
    #   2. `modelfile`: path to a saved model file
    #   3. `model_str`: a model serialized to a string
    # Additional named arguments in `...` are appended to `params`.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      params <- append(params, list(...))
      handle <- lgb.null.handle()

      # Validation errors and native-call errors are intentionally NOT wrapped
      # in try(): they propagate with their original message instead of being
      # printed and then replaced by the generic "cannot create Booster
      # handle" error below.
      if (!is.null(train_set)) {

        if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
          stop("lgb.Booster: Can only use lgb.Dataset as training data")
        }

        train_set_handle <- train_set$.__enclos_env__$private$get_handle()
        params <- modifyList(params, train_set$get_params())
        params_str <- lgb.params2str(params = params)

        # Create the booster; the native call is expected to fill `handle`
        # in place (verified after this if/else chain)
        call_state <- 0L
        .Call(
          LGBM_BoosterCreate_R
          , train_set_handle
          , params_str
          , handle
          , call_state
        )

        # Remember the training data and the version it had when attached
        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version
        private$num_dataset <- 1L
        private$init_predictor <- train_set$.__enclos_env__$private$predictor

        # If the dataset carries an initial predictor, merge its model in
        if (!is.null(private$init_predictor)) {
          call_state <- 0L
          .Call(
            LGBM_BoosterMerge_R
            , handle
            , private$init_predictor$.__enclos_env__$private$handle
            , call_state
          )
        }

        # Prediction-cache flag for the training set (dataset index 1)
        private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      } else if (!is.null(modelfile)) {

        if (!is.character(modelfile)) {
          stop("lgb.Booster: Can only use a string as model file path")
        }

        call_state <- 0L
        .Call(
          LGBM_BoosterCreateFromModelfile_R
          , lgb.c_str(x = modelfile)
          , handle
          , call_state
        )

      } else if (!is.null(model_str)) {

        if (!is.character(model_str)) {
          stop("lgb.Booster: Can only use a string as model_str")
        }

        call_state <- 0L
        .Call(
          LGBM_BoosterLoadModelFromString_R
          , lgb.c_str(x = model_str)
          , handle
          , call_state
        )

      } else {

        stop(
          "lgb.Booster: Need at least either training dataset, "
          , "model file, or model_str to create booster instance"
        )

      }

      # Safety net in case a native call returned without setting the handle
      if (isTRUE(lgb.is.null.handle(x = handle))) {
        stop("lgb.Booster: cannot create Booster handle")
      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      # private$num_class is passed so the native call can presumably fill
      # it in place (same convention as current_iter())
      private$num_class <- 1L
      call_state <- 0L
      .Call(
        LGBM_BoosterGetNumClasses_R
        , private$handle
        , private$num_class
        , call_state
      )

      self$params <- params

      return(invisible(NULL))

    },

    # Set the name used for the training data in evaluation results
    set_train_data_name = function(name) {

      private$name_train_set <- name
      return(invisible(self))

    },

    # Register an lgb.Dataset as validation data under `name`
    add_valid = function(data, name) {

      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # Validation data must share the training data's initial predictor
      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      call_state <- 0L
      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
        , call_state
      )

      # Validation sets occupy dataset indices 2, 3, ... (index 1 is training)
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Merge new parameters into the current ones and push them to the booster
    reset_parameter = function(params, ...) {

      if (methods::is(self$params, "list")) {
        params <- modifyList(self$params, params)
      }

      params <- modifyList(params, list(...))
      params_str <- lgb.params2str(params = params)

      call_state <- 0L
      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
        , call_state
      )
      self$params <- params

      return(invisible(self))

    },

    # Perform one boosting iteration.
    #   train_set: optional replacement lgb.Dataset. When NULL, the stored
    #              training set is re-attached only if its version changed.
    #   fobj:      optional custom objective, called as
    #              fobj(preds, train_set). It must return a list with
    #              equal-length `grad` and `hess`.
    update = function(train_set = NULL, fobj = NULL) {

      # Boosters loaded from a model file/string have no stored training set.
      # Guard against NULL, otherwise the version comparison yields
      # logical(0) and `if` fails with "argument is of length zero".
      if (is.null(train_set) && !is.null(private$train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Compare against the Dataset's private predictor, as initialize() and
        # add_valid() do. (`train_set$predictor` is presumably not a public
        # lgb.Dataset field, so it would always read as NULL.)
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        call_state <- 0L
        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
          , call_state
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # Once a custom objective has been used, the objective is "none"
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        call_state <- 0L
        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
          , call_state
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }

        # Switch the booster to objective "none" the first time a custom
        # objective is used
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective should "
            , "return a list with attributes (hess, grad)"
          )
        }

        # Only length(gpair$grad) is passed to the native call, so hess must
        # have exactly the same length
        if (length(gpair$grad) != length(gpair$hess)) {
          stop(
            "lgb.Booster.update: custom objective returned grad and hess "
            , "of different lengths"
          )
        }

        call_state <- 0L
        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
          , call_state
        )

      }

      # The model changed, so every cached prediction is stale
      private$reset_predicted_flags()

      return(invisible(self))

    },

    # Undo the most recent boosting iteration
    rollback_one_iter = function() {

      call_state <- 0L
      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
        , call_state
      )

      # The model changed, so every cached prediction is stale
      private$reset_predicted_flags()

      return(invisible(self))

    },

    # Get the current iteration (the native call fills `cur_iter` in place)
    current_iter = function() {

      cur_iter <- 0L
      call_state <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
        , call_state
      )
      return(cur_iter)

    },

    # Get the model's upper bound value (filled in place by the native call)
    upper_bound = function() {

      upper_bound <- 0.0
      call_state <- 0L
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
        , call_state
      )
      return(upper_bound)

    },

    # Get the model's lower bound value (filled in place by the native call)
    lower_bound = function() {

      lower_bound <- 0.0
      call_state <- 0L
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
        , call_state
      )
      return(lower_bound)

    },

    # Evaluate `data` on the configured metrics (plus optional `feval`).
    # Data not yet known to the booster is registered as validation data.
    eval = function(data, name, feval = NULL) {

      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the dataset: index 1 is training, 2.. are validation sets
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1L
            break
          }
        }
      }

      # Unknown dataset: register it as validation data and use its new index
      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },

    # Evaluate the training data (dataset index 1)
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluate every validation set and concatenate the results
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },

    # Save the model to `filename`; num_iteration = NULL means best_iter
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      call_state <- 0L
      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , lgb.c_str(x = filename)
        , call_state
      )

      return(invisible(self))
    },

    # Serialize the model to a string; num_iteration = NULL means best_iter
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Start with a 1 MiB buffer; the native call reports the required size
      # in `act_len`, and the call is repeated with a larger buffer if needed
      buf_len <- as.integer(1024L * 1024L)
      act_len <- 0L
      buf <- raw(buf_len)
      call_state <- 0L
      .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , buf_len
        , act_len
        , buf
        , call_state
      )

      if (act_len > buf_len) {
        buf_len <- act_len
        buf <- raw(buf_len)
        call_state <- 0L
        .Call(
          LGBM_BoosterSaveModelToString_R
          , private$handle
          , as.integer(num_iteration)
          , as.integer(feature_importance_type)
          , buf_len
          , act_len
          , buf
          , call_state
        )
      }

      return(
        lgb.encode.char(arr = buf, len = act_len)
      )

    },

    # Dump the model (in memory) to a string; num_iteration = NULL means best_iter
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Same grow-and-retry buffer protocol as save_model_to_string()
      buf_len <- as.integer(1024L * 1024L)
      act_len <- 0L
      buf <- raw(buf_len)
      call_state <- 0L
      .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , buf_len
        , act_len
        , buf
        , call_state
      )

      if (act_len > buf_len) {
        buf_len <- act_len
        buf <- raw(buf_len)
        call_state <- 0L
        .Call(
          LGBM_BoosterDumpModel_R
          , private$handle
          , as.integer(num_iteration)
          , as.integer(feature_importance_type)
          , buf_len
          , act_len
          , buf
          , call_state
        )
      }

      return(
        lgb.encode.char(arr = buf, len = act_len)
      )

    },

    # Predict on new data. NULL num_iteration means best_iter and NULL
    # start_iteration means 0L. `...` is forwarded to Predictor$new().
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE, ...) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      predictor <- Predictor$new(private$handle, ...)
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
          , reshape = reshape
        )
      )

    },

    # Wrap this booster's handle in a Predictor
    to_predictor = function() {
      return(Predictor$new(private$handle))
    },

    # Model string stored by save() for in-memory saving
    raw = NA,

    # Store the serialized model in `self$raw`
    save = function() {

      self$raw <- self$save_model_to_string(NULL)

      return(invisible(NULL))

    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,
    train_set_version = 0L,

    # Mark every dataset's cached predictions as stale (after the model changed)
    reset_predicted_flags = function() {
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }
      return(invisible(NULL))
    },

    # Return the (cached) raw predictions for dataset index `idx`
    # (1 = training, 2.. = validation sets)
    inner_predict = function(idx) {

      # Validate before using idx to look up a name, so an out-of-range idx
      # gets this message rather than "subscript out of bounds"
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the prediction buffer on first use, sized by the native call
      if (is.null(private$predict_buffer[[data_name]])) {

        call_state <- 0L
        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
          , call_state
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer only if it is stale for the current iteration
      if (!private$is_predicted_cur_iter[[idx]]) {

        call_state <- 0L
        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
          , call_state
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Fetch (once) and cache the booster's evaluation metric names
    get_eval_info = function() {

      if (is.null(private$eval_names)) {

        # Same grow-and-retry buffer protocol as save_model_to_string()
        buf_len <- as.integer(1024L * 1024L)
        act_len <- 0L
        buf <- raw(buf_len)
        call_state <- 0L
        .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
          , buf_len
          , act_len
          , buf
          , call_state
        )
        if (act_len > buf_len) {
          buf_len <- act_len
          buf <- raw(buf_len)
          call_state <- 0L
          .Call(
            LGBM_BoosterGetEvalNames_R
            , private$handle
            , buf_len
            , act_len
            , buf
            , call_state
          )
        }
        eval_names <- lgb.encode.char(arr = buf, len = act_len)

        # Names come back as a single tab-separated string
        if (nchar(eval_names) > 0L) {

          eval_names <- strsplit(eval_names, "\t")[[1L]]
          private$eval_names <- eval_names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", eval_names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },

    # Evaluate dataset `data_idx` on built-in metrics plus optional `feval`.
    # Returns a list of results, each with data_name, name, value and
    # higher_better.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # The native call fills tmp_vals with one value per built-in metric
        tmp_vals <- numeric(length(private$eval_names))
        call_state <- 0L
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
          , call_state
        )

        for (i in seq_along(private$eval_names)) {
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))
        }

      }

      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # Pick the dataset the custom metric should see
        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
          stop(
            "lgb.Booster.eval: custom eval function should return a "
            , "list with attribute (name, value, higher_better)"
          )
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }

  )
)

842
843
844
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param start_iteration int or \code{NULL}, optional (default = \code{NULL})
#'                        Start index of the iteration to predict.
#'                        If \code{NULL} or <= 0, starts from the first iteration.
#' @param num_iteration int or \code{NULL}, optional (default = \code{NULL})
#'                      Limit number of iterations in the prediction.
#'                      If \code{NULL}, if the best iteration exists and start_iteration is \code{NULL} or <= 0,
#'                      the best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the form of the original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf indices instead.
#' @param predcontrib whether to return per-feature contributions for each record.
#' @param header only used for prediction on a text file. \code{TRUE} if the text file has a header.
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
#'
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' preds <- predict(model, test$data)
#' }
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  # Guard clause: only an lgb.Booster knows how to predict
  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the R6 method, forwarding every option (and `...`) unchanged;
  # the result of this call is the value of the function
  object$predict(
    data = data
    , start_iteration = start_iteration
    , num_iteration = num_iteration
    , rawscore = rawscore
    , predleaf = predleaf
    , predcontrib = predcontrib
    , header = header
    , reshape = reshape
    , ...
  )
}

925
926
927
928
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file and
#'              \code{model_str} is ignored.
#' @param filename path of model file (a single string)
#' @param model_str a string containing the model
#'
#' @return lgb.Booster (returned invisibly)
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  filename_provided <- !is.null(filename)
  model_str_provided <- !is.null(model_str)

  # filename takes precedence over model_str when both are given
  if (filename_provided) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # file.exists() is vectorized: a zero-length or multi-element path would
    # make the `if` below fail with an opaque "condition has length" error
    if (length(filename) != 1L) {
      stop("lgb.load: filename should be a single string")
    }
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    return(invisible(Booster$new(modelfile = filename)))
  }

  if (model_str_provided) {
    if (!is.character(model_str)) {
      stop("lgb.load: model_str should be character")
    }
    return(invisible(Booster$new(model_str = model_str)))
  }

  stop("lgb.load: either filename or model_str must be given")

}

985
986
987
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save a LightGBM model to a file.
#' @param booster Object of class \code{lgb.Booster}
#' @param filename path of the file to save the model to (a single string)
#' @param num_iteration number of iterations to save; NULL or <= 0 means use the best iteration
#'
#' @return lgb.Booster (returned invisibly)
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }

  # Require exactly one path so save_model() never receives a vector
  if (!(is.character(filename) && length(filename) == 1L)) {
    stop("lgb.save: filename should be a string")
  }

  # Delegate to the Booster's save_model() method; the result is returned
  # invisibly so that calling lgb.save() at the console prints nothing
  return(
    invisible(booster$save_model(
      filename = filename
      , num_iteration = num_iteration
    ))
  )

}

1037
1038
1039
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to dump; NULL or <= 0 means use the best iteration
#'
#' @return JSON representation of the model, as returned by the booster's
#'         \code{dump_model()} method
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # The error message must name this function, not lgb.save()
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Return booster at requested iteration
  return(booster$dump_model(num_iteration = num_iteration))

}

1079
1080
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for (a single string).
#' @param eval_name Name of the evaluation metric to return results for (a single string).
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err TRUE will return evaluation error instead
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  # Check if booster is booster
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  # Check if data and evaluation name are characters or not
  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # `%in%` is vectorized, so non-scalar names would otherwise reach the `if`
  # checks below and fail with an opaque "condition has length" error
  if (length(data_name) != 1L || length(eval_name) != 1L) {
    stop("lgb.get.eval.result: data_name and eval_name should each be a single string")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  # Check if evaluation result is existing
  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Pick either the recorded metric values or their recorded errors
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  # No specific iterations requested: return everything that was recorded
  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Recorded values are stored from position 1 onwards, but the first position
  # corresponds to iteration `start_iter`, so shift the requested iterations
  iters <- as.integer(iters)
  delta <- booster$record_evals[["start_iter"]] - 1L
  iters <- iters - delta

  # Return requested result
  return(as.numeric(result[iters]))
}