lgb.Booster.R 31.3 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
#' @importFrom R6 R6Class
# R6 class wrapping a native LightGBM booster handle.
# A booster is created from exactly one of: a training lgb.Dataset, a model file, or a model string.
# Datasets are addressed by index: 1 is the training set, i + 1 is the i-th validation set.
# NOTE(review): many native routines below write their results into the R vectors passed as their
# last argument(s) (e.g. `cur_iter` in current_iter()), so those locals must not be reused elsewhere.
Booster <- R6::R6Class(
  classname = "lgb.Booster",
  cloneable = FALSE,
  public = list(

    # Iteration / score of the best model; best_iter is used as the default
    # num_iteration by save_model(), save_model_to_string(), dump_model() and predict()
    best_iter = -1L,
    best_score = NA_real_,
    # Parameters the booster was created with (updated by reset_parameter())
    params = list(),
    # NOTE(review): not written inside this class; presumably filled by the training loop
    record_evals = list(),

    # Free the native booster handle (R6 calls finalize() on garbage collection)
    finalize = function() {

      # Only free a handle that actually exists
      if (!lgb.is.null.handle(x = private$handle)) {

        .Call(
          LGBM_BoosterFree_R
          , private$handle
        )
        private$handle <- NULL

      }

      return(invisible(NULL))

    },

    # Create a booster from one source, checked in this order:
    #   train_set - an lgb.Dataset to train on
    #   modelfile - path to a saved model file
    #   model_str - a model serialized as a character string
    # Extra named arguments in ... are appended to params.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          model_str = NULL,
                          ...) {

      # Create parameters and handle
      params <- append(params, list(...))
      handle <- lgb.null.handle()

      # Attempt to create the native handle. Errors raised inside are reported by try()
      # instead of being propagated; the null-handle check below turns a failed creation
      # into an error.
      # NOTE(review): an error raised after the handle was created (e.g. in
      # LGBM_BoosterMerge_R) does not abort construction.
      try({

        if (!is.null(train_set)) {

          if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
            stop("lgb.Booster: Can only use lgb.Dataset as training data")
          }
          train_set_handle <- train_set$.__enclos_env__$private$get_handle()
          # Dataset parameters override booster parameters with the same name
          params <- modifyList(params, train_set$get_params())
          params_str <- lgb.params2str(params = params)

          .Call(
            LGBM_BoosterCreate_R
            , train_set_handle
            , params_str
            , handle
          )

          # Remember the training data and the version it had when attached
          private$train_set <- train_set
          private$train_set_version <- train_set$.__enclos_env__$private$version
          private$num_dataset <- 1L
          private$init_predictor <- train_set$.__enclos_env__$private$predictor

          # Merge the Dataset's initial predictor (if any) into the new booster
          if (!is.null(private$init_predictor)) {

            .Call(
              LGBM_BoosterMerge_R
              , handle
              , private$init_predictor$.__enclos_env__$private$handle
            )

          }

          # The training set starts without a cached prediction
          private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

        } else if (!is.null(modelfile)) {

          if (!is.character(modelfile)) {
            stop("lgb.Booster: Can only use a string as model file path")
          }

          .Call(
            LGBM_BoosterCreateFromModelfile_R
            , lgb.c_str(x = modelfile)
            , handle
          )

        } else if (!is.null(model_str)) {

          if (!is.character(model_str)) {
            stop("lgb.Booster: Can only use a string as model_str")
          }

          .Call(
            LGBM_BoosterLoadModelFromString_R
            , lgb.c_str(x = model_str)
            , handle
          )

        } else {

          stop(
            "lgb.Booster: Need at least either training dataset, "
            , "model file, or model_str to create booster instance"
          )

        }

      })

      # Check whether the handle was created properly if it was not stopped earlier by a stop call
      if (isTRUE(lgb.is.null.handle(x = handle))) {

        stop("lgb.Booster: cannot create Booster handle")

      } else {

        class(handle) <- "lgb.Booster.handle"
        private$handle <- handle
        # Filled in by the native routine
        private$num_class <- 1L
        .Call(
          LGBM_BoosterGetNumClasses_R
          , private$handle
          , private$num_class
        )

      }

      self$params <- params

      return(invisible(NULL))

    },

    # Set the name used for the training set in evaluation results
    set_train_data_name = function(name) {

      private$name_train_set <- name
      return(invisible(self))

    },

    # Register an lgb.Dataset as an additional validation set named `name`
    add_valid = function(data, name) {

      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
      }

      # Validation data must share the training data's initial predictor
      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Failed to add validation data; "
          , "you should use the same predictor for these data"
        )
      }

      if (!is.character(name)) {
        stop("lgb.Booster.add_valid: Can only use characters as data name")
      }

      .Call(
        LGBM_BoosterAddValidData_R
        , private$handle
        , data$.__enclos_env__$private$get_handle()
      )

      # Keep the R-side bookkeeping aligned with the native dataset indices
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1L
      private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)

      return(invisible(self))

    },

    # Merge `params` and ... into the current parameters and push them to the native booster
    reset_parameter = function(params, ...) {

      if (methods::is(self$params, "list")) {
        params <- modifyList(self$params, params)
      }

      params <- modifyList(params, list(...))
      params_str <- lgb.params2str(params = params)

      .Call(
        LGBM_BoosterResetParameter_R
        , private$handle
        , params_str
      )
      self$params <- params

      return(invisible(self))

    },

    # Run one boosting iteration.
    #   train_set - optional new training lgb.Dataset
    #   fobj      - optional custom objective: function(preds, dtrain) returning list(grad, hess)
    update = function(train_set = NULL, fobj = NULL) {

      # Re-attach the stored training set if it changed since it was attached.
      # Boosters loaded from a model file have no stored training set; guard against
      # comparing NULL, which would give if() a zero-length condition.
      if (is.null(train_set) && !is.null(private$train_set)) {
        if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
          train_set <- private$train_set
        }
      }

      if (!is.null(train_set)) {

        if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }

        # Compare the Dataset's private predictor, the same field add_valid() checks
        # (`train_set$predictor` is not a Dataset field and is always NULL)
        if (!identical(train_set$.__enclos_env__$private$predictor, private$init_predictor)) {
          stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
        }

        .Call(
          LGBM_BoosterResetTrainingData_R
          , private$handle
          , train_set$.__enclos_env__$private$get_handle()
        )

        private$train_set <- train_set
        private$train_set_version <- train_set$.__enclos_env__$private$version

      }

      if (is.null(fobj)) {

        # Once a custom objective was used, the native objective is "none"
        if (private$set_objective_to_none) {
          stop("lgb.Booster.update: cannot update due to null objective function")
        }

        .Call(
          LGBM_BoosterUpdateOneIter_R
          , private$handle
        )

      } else {

        if (!is.function(fobj)) {
          stop("lgb.Booster.update: fobj should be a function")
        }
        # Disable the built-in objective so only the custom gradients are used
        if (!private$set_objective_to_none) {
          self$reset_parameter(params = list(objective = "none"))
          private$set_objective_to_none <- TRUE
        }

        gpair <- fobj(private$inner_predict(1L), private$train_set)

        if (is.null(gpair$grad) || is.null(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should
            return a list with attributes (hess, grad)")
        }

        # Only length(gpair$grad) is passed to the native routine, so a mismatch cannot be
        # detected there. NOTE(review): presumably a shorter hess would be read out of bounds.
        if (length(gpair$grad) != length(gpair$hess)) {
          stop("lgb.Booster.update: custom objective should return grad and hess of the same length")
        }

        .Call(
          LGBM_BoosterUpdateOneIterCustom_R
          , private$handle
          , gpair$grad
          , gpair$hess
          , length(gpair$grad)
        )

      }

      # The model changed: every cached prediction is stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Undo the last boosting iteration
    rollback_one_iter = function() {

      .Call(
        LGBM_BoosterRollbackOneIter_R
        , private$handle
      )

      # The model changed: every cached prediction is stale
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }

      return(invisible(self))

    },

    # Number of the current iteration (written into `cur_iter` by the native routine)
    current_iter = function() {

      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Upper bound value reported by the native booster
    upper_bound = function() {

      upper_bound <- 0.0
      .Call(
        LGBM_BoosterGetUpperBoundValue_R
        , private$handle
        , upper_bound
      )
      return(upper_bound)

    },

    # Lower bound value reported by the native booster
    lower_bound = function() {

      lower_bound <- 0.0
      .Call(
        LGBM_BoosterGetLowerBoundValue_R
        , private$handle
        , lower_bound
      )
      return(lower_bound)

    },

    # Evaluate `data` on the configured metrics plus an optional custom `feval`.
    # A Dataset that is neither the training set nor a known validation set is
    # registered as a new validation set under `name`.
    eval = function(data, name, feval = NULL) {

      if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) {
        stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
      }

      # Locate the dataset index (0L = not registered yet)
      data_idx <- 0L
      if (identical(data, private$train_set)) {
        data_idx <- 1L
      } else {
        if (length(private$valid_sets) > 0L) {
          for (i in seq_along(private$valid_sets)) {
            if (identical(data, private$valid_sets[[i]])) {
              data_idx <- i + 1L
              break
            }
          }
        }
      }

      if (data_idx == 0L) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }

      return(
        private$inner_eval(
          data_name = name
          , data_idx = data_idx
          , feval = feval
        )
      )

    },

    # Evaluate the training data
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1L, feval))
    },

    # Evaluate all validation sets; returns a flat list of evaluation records
    eval_valid = function(feval = NULL) {

      ret <- list()

      if (length(private$valid_sets) <= 0L) {
        return(ret)
      }

      for (i in seq_along(private$valid_sets)) {
        ret <- append(
          x = ret
          , values = private$inner_eval(private$name_valid_sets[[i]], i + 1L, feval)
        )
      }

      return(ret)

    },

    # Save the model to `filename`; num_iteration defaults to best_iter
    save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      .Call(
        LGBM_BoosterSaveModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , lgb.c_str(x = filename)
      )

      return(invisible(self))
    },

    # Serialize the model to a character string; num_iteration defaults to best_iter
    save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Start with a 1 MiB buffer; the native routine reports the needed size in act_len
      buf_len <- as.integer(1024L * 1024L)
      act_len <- 0L
      buf <- raw(buf_len)

      .Call(
        LGBM_BoosterSaveModelToString_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , buf_len
        , act_len
        , buf
      )

      # Retry once with an exactly-sized buffer if the first one was too small
      if (act_len > buf_len) {
        buf_len <- act_len
        buf <- raw(buf_len)
        .Call(
          LGBM_BoosterSaveModelToString_R
          , private$handle
          , as.integer(num_iteration)
          , as.integer(feature_importance_type)
          , buf_len
          , act_len
          , buf
        )
      }

      return(
        lgb.encode.char(arr = buf, len = act_len)
      )

    },

    # Dump the model (as returned by LGBM_BoosterDumpModel_R) to a character string
    dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }

      # Start with a 1 MiB buffer; the native routine reports the needed size in act_len
      buf_len <- as.integer(1024L * 1024L)
      act_len <- 0L
      buf <- raw(buf_len)
      .Call(
        LGBM_BoosterDumpModel_R
        , private$handle
        , as.integer(num_iteration)
        , as.integer(feature_importance_type)
        , buf_len
        , act_len
        , buf
      )

      # Retry once with an exactly-sized buffer if the first one was too small
      if (act_len > buf_len) {
        buf_len <- act_len
        buf <- raw(buf_len)
        .Call(
          LGBM_BoosterDumpModel_R
          , private$handle
          , as.integer(num_iteration)
          , as.integer(feature_importance_type)
          , buf_len
          , act_len
          , buf
        )
      }

      return(
        lgb.encode.char(arr = buf, len = act_len)
      )

    },

    # Predict on new data; ... is forwarded to Predictor$new()
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE, ...) {

      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      predictor <- Predictor$new(private$handle, ...)
      return(
        predictor$predict(
          data = data
          , start_iteration = start_iteration
          , num_iteration = num_iteration
          , rawscore = rawscore
          , predleaf = predleaf
          , predcontrib = predcontrib
          , header = header
          , reshape = reshape
        )
      )

    },

    # Wrap this booster's handle in a Predictor
    to_predictor = function() {
      return(Predictor$new(private$handle))
    },

    # Model string filled in by save()
    raw = NA,

    # Store the model as a string in self$raw for in-memory saving
    save = function() {

      self$raw <- self$save_model_to_string(NULL)

      return(invisible(NULL))

    }

  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    # Prediction buffers keyed by dataset name
    predict_buffer = list(),
    # One flag per dataset: is predict_buffer current for the latest iteration?
    is_predicted_cur_iter = list(),
    num_class = 1L,
    num_dataset = 0L,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,
    set_objective_to_none = FALSE,
    train_set_version = 0L,

    # Return the current predictions for dataset `idx`, fetching them only when stale
    inner_predict = function(idx) {

      # Validate the index before using it to look up a dataset name
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      data_name <- private$name_train_set
      if (idx > 1L) {
        data_name <- private$name_valid_sets[[idx - 1L]]
      }

      # Allocate the buffer on first use (size written into npred by the native routine)
      if (is.null(private$predict_buffer[[data_name]])) {

        npred <- 0L
        .Call(
          LGBM_BoosterGetNumPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , npred
        )
        private$predict_buffer[[data_name]] <- numeric(npred)

      }

      # Refresh the buffer only if the model changed since the last fetch
      if (!private$is_predicted_cur_iter[[idx]]) {

        .Call(
          LGBM_BoosterGetPredict_R
          , private$handle
          , as.integer(idx - 1L)
          , private$predict_buffer[[data_name]]
        )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }

      return(private$predict_buffer[[data_name]])
    },

    # Fetch (once) and cache the evaluation metric names and their "higher is better" flags
    get_eval_info = function() {

      if (is.null(private$eval_names)) {

        # Start with a 1 MiB buffer; retry with an exactly-sized one if needed
        buf_len <- as.integer(1024L * 1024L)
        act_len <- 0L
        buf <- raw(buf_len)
        .Call(
          LGBM_BoosterGetEvalNames_R
          , private$handle
          , buf_len
          , act_len
          , buf
        )
        if (act_len > buf_len) {
          buf_len <- act_len
          buf <- raw(buf_len)
          .Call(
            LGBM_BoosterGetEvalNames_R
            , private$handle
            , buf_len
            , act_len
            , buf
          )
        }
        names <- lgb.encode.char(arr = buf, len = act_len)

        # Names arrive tab-separated
        if (nchar(names) > 0L) {

          names <- strsplit(names, "\t")[[1L]]
          private$eval_names <- names

          # some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
          # ndcg metric evaluated at the first "query result" in learning-to-rank
          metric_names <- gsub("@.*", "", names)
          private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

        }

      }

      return(private$eval_names)

    },

    # Evaluate dataset `data_idx` (1 = training) on built-in metrics and an optional feval.
    # Returns a list of records with data_name, name, value and higher_better.
    inner_eval = function(data_name, data_idx, feval = NULL) {

      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }

      private$get_eval_info()

      ret <- list()

      if (length(private$eval_names) > 0L) {

        # Values are written into tmp_vals by the native routine
        tmp_vals <- numeric(length(private$eval_names))
        .Call(
          LGBM_BoosterGetEval_R
          , private$handle
          , as.integer(data_idx - 1L)
          , tmp_vals
        )

        for (i in seq_along(private$eval_names)) {

          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))

        }

      }

      if (!is.null(feval)) {

        if (!is.function(feval)) {
          stop("lgb.Booster.eval: feval should be a function")
        }

        # The custom metric receives the Dataset object matching data_idx
        data <- private$train_set
        if (data_idx > 1L) {
          data <- private$valid_sets[[data_idx - 1L]]
        }

        res <- feval(private$inner_predict(data_idx), data)

        if (is.null(res$name) || is.null(res$value) ||  is.null(res$higher_better)) {
          stop("lgb.Booster.eval: custom eval function should return a
            list with attribute (name, value, higher_better)");
        }

        res$data_name <- data_name
        ret <- append(ret, list(res))
      }

      return(ret)

    }

  )
)

792
793
794
#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param start_iteration integer or NULL, optional (default=NULL)
#'                        Start index of the iteration to predict.
#'                        If NULL or <= 0, starts from the first iteration.
#' @param num_iteration integer or NULL, optional (default=NULL)
#'                      Limit number of iterations in the prediction.
#'                      If NULL, if the best iteration exists and start_iteration is NULL or <= 0, the
#'                      best iteration is used; otherwise, all iterations from start_iteration are used.
#'                      If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the form of the original untransformed
#'                 sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#'                 for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether to predict leaf indices instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#'                prediction outputs per case.
#' @param ... Additional named arguments passed to the \code{predict()} method of
#'            the \code{lgb.Booster} object passed to \code{object}.
#' @return For regression or binary classification, it returns a vector of length \code{nrow(data)}.
#'         For multiclass classification, either a \code{num_class * nrow(data)} vector or
#'         a \code{(nrow(data), num_class)} dimension matrix is returned, depending on
#'         the \code{reshape} value.
#'
#'         When \code{predleaf = TRUE}, the output is a matrix object with the
#'         number of columns corresponding to the number of trees.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#' preds <- predict(model, test$data)
#' }
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                start_iteration = NULL,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                predcontrib = FALSE,
                                header = FALSE,
                                reshape = FALSE,
                                ...) {

  # Guard clause: only an lgb.Booster provides the $predict() method delegated to below
  if (!lgb.is.Booster(x = object)) {
    stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
  }

  # Forward every argument unchanged; the booster applies the defaults for NULL iterations
  object$predict(
    data = data,
    start_iteration = start_iteration,
    num_iteration = num_iteration,
    rawscore = rawscore,
    predleaf = predleaf,
    predcontrib = predcontrib,
    header = header,
    reshape = reshape,
    ...
  )
}

875
876
877
878
#' @name lgb.load
#' @title Load LightGBM model
#' @description Load a LightGBM model from either a file path or a model string.
#'              If both are provided, the model is loaded from the file.
#' @param filename path of model file
#' @param model_str a character string containing the model
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 3L
#' )
#' model_file <- tempfile(fileext = ".txt")
#' lgb.save(model, model_file)
#' load_booster <- lgb.load(filename = model_file)
#' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string)
#' }
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

  # A file path takes priority over a model string when both are supplied
  if (!is.null(filename)) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    if (!file.exists(filename)) {
      stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
    }
    booster <- Booster$new(modelfile = filename)
    return(invisible(booster))
  }

  # Neither source given: nothing to load
  if (is.null(model_str)) {
    stop("lgb.load: either filename or model_str must be given")
  }

  if (!is.character(model_str)) {
    stop("lgb.load: model_str should be character")
  }

  invisible(Booster$new(model_str = model_str))
}

935
936
937
#' @name lgb.save
#' @title Save LightGBM model
#' @description Save a LightGBM model to a file on disk.
#' @param booster Object of class \code{lgb.Booster}
#' @param filename path of the file the model is written to (a single string)
#' @param num_iteration number of iterations to save, NULL or <= 0 means use best iteration
#'
#' @return lgb.Booster
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {

  # Validate both inputs before asking the booster to write anything
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }
  filename_is_string <- is.character(filename) && length(filename) == 1L
  if (!filename_is_string) {
    stop("lgb.save: filename should be a string")
  }

  # The write itself is delegated to the Booster; its result is returned invisibly
  saved <- booster$save_model(
    filename = filename
    , num_iteration = num_iteration
  )
  return(invisible(saved))
}

987
988
989
#' @name lgb.dump
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dump, NULL or <= 0 means use best iteration
#'
#' @return json format of model
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 10L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#'   , early_stopping_rounds = 5L
#' )
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {

  # The error prefix names this function (it previously said "lgb.save:",
  # which pointed users at the wrong call)
  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

  # Delegate to the Booster's dump_model() for the requested iteration
  return(booster$dump_model(num_iteration = num_iteration))
}

1029
1030
#' @name lgb.get.eval.result
#' @title Get record evaluation result from booster
#' @description Given a \code{lgb.Booster}, return evaluation results for a
#'              particular metric on a particular dataset.
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name Name of the dataset to return evaluation results for.
#' @param eval_name Name of the evaluation metric to return results for.
#' @param iters An integer vector of iterations you want to get evaluation results for. If NULL
#'              (the default), evaluation results for all iterations will be returned.
#' @param is_err If \code{TRUE}, return the values recorded under the evaluation-error
#'               key instead of the evaluation values.
#'
#' @return numeric vector of evaluation result
#'
#' @examples
#' \donttest{
#' # train a regression model
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' params <- list(objective = "regression", metric = "l2")
#' valids <- list(test = dtest)
#' model <- lgb.train(
#'   params = params
#'   , data = dtrain
#'   , nrounds = 5L
#'   , valids = valids
#'   , min_data = 1L
#'   , learning_rate = 1.0
#' )
#'
#' # Examine valid data_name values
#' print(setdiff(names(model$record_evals), "start_iter"))
#'
#' # Examine valid eval_name values for dataset "test"
#' print(names(model$record_evals[["test"]]))
#'
#' # Get L2 values for "test" dataset
#' lgb.get.eval.result(model, "test", "l2")
#' }
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {

  if (!lgb.is.Booster(x = booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }

  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }

  # Both names are used as scalar lookup keys. Without this check, a vector
  # would make the `if (!(x %in% y))` checks below fail with a
  # condition-length error instead of a clear message.
  if (length(data_name) != 1L || length(eval_name) != 1L) {
    stop("lgb.get.eval.result: data_name and eval_name should each be a single string")
  }

  # NOTE: "start_iter" exists in booster$record_evals but is not a valid data_name
  data_names <- setdiff(names(booster$record_evals), "start_iter")
  if (!(data_name %in% data_names)) {
    stop(paste0(
      "lgb.get.eval.result: data_name "
      , shQuote(data_name)
      , " not found. Only the following datasets exist in record evals: ["
      , paste(data_names, collapse = ", ")
      , "]"
    ))
  }

  eval_names <- names(booster$record_evals[[data_name]])
  if (!(eval_name %in% eval_names)) {
    stop(paste0(
      "lgb.get.eval.result: eval_name "
      , shQuote(eval_name)
      , " not found. Only the following eval_names exist for dataset "
      , shQuote(data_name)
      , ": ["
      , paste(eval_names, collapse = ", ")
      , "]"
    ))
  }

  # Choose which recorded series to read: the values or their errors
  result_key <- if (is_err) .EVAL_ERR_KEY() else .EVAL_KEY()
  result <- booster$record_evals[[data_name]][[eval_name]][[result_key]]

  # No specific iterations requested: return the whole series
  if (is.null(iters)) {
    return(as.numeric(result))
  }

  # Convert the requested iteration numbers into positions in `result` by
  # subtracting an offset of (start_iter - 1)
  iters <- as.integer(iters)
  delta <- booster$record_evals$start_iter - 1.0
  iters <- iters - delta

  return(as.numeric(result[iters]))
}