# lgb.Predictor.R
# Committed by Guolin Ke.
# R6 class wrapping a native LightGBM booster handle for prediction.
#
# The handle is either created from a model file (in which case this object
# owns it and frees it in `finalize`) or borrowed from an existing
# `lgb.Booster.handle` (in which case it is never freed here).
Predictor <- R6Class(
  "lgb.Predictor",
  cloneable = FALSE,
  public = list(

    # Free the native booster, but only if this object created it.
    finalize = function() {
      if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {
        lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
        private$handle <- NULL
      }
    },

    # modelfile: a character path to a saved model file, or an existing
    #   `lgb.Booster.handle` (borrowed; not freed by this object).
    # Errors for any other input type.
    initialize = function(modelfile) {
      handle <- lgb.new.handle()
      if (is.character(modelfile)) {
        handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret = handle, lgb.c_str(modelfile))
        private$need_free_handle <- TRUE
      } else if (inherits(modelfile, "lgb.Booster.handle")) {
        handle <- modelfile
        private$need_free_handle <- FALSE
      } else {
        stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")
      }
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
    },

    # Returns the current iteration as reported by the native booster.
    current_iter = function() {
      cur_iter <- 0L
      lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)
    },

    # Predict on `data`, which may be:
    #   * a character filename: the native library predicts on the file and
    #     writes a tab-separated result to a temporary file, which is read back;
    #   * a dense `matrix`;
    #   * a sparse `dgCMatrix` (column-compressed).
    # num_iteration: NULL is sent to the native side as -1
    #   (presumably "use all iterations" -- confirm against the C API).
    # rawscore, predleaf, header: forwarded to the native call as integers.
    # reshape: if TRUE and there is more than one prediction per row, return
    #   a matrix with one row per input row; otherwise a flat vector.
    predict = function(data, num_iteration = NULL, rawscore = FALSE,
      predleaf = FALSE, header = FALSE, reshape = FALSE) {

      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }

      if (is.character(data)) {
        tmp_filename <- tempfile(pattern = "lightgbm_")
        on.exit(unlink(tmp_filename), add = TRUE)
        lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data,
          as.integer(header),
          as.integer(rawscore),
          as.integer(predleaf),
          as.integer(num_iteration),
          lgb.c_str(tmp_filename))
        # `sep` (not `seq`): read.table has no `...`, so a misspelled argument
        # is an "unused argument" error.
        preds   <- read.delim(tmp_filename, header = FALSE, sep = "\t")
        num_row <- nrow(preds)
        # Flatten row by row so each input row's predictions stay contiguous.
        preds   <- as.vector(t(preds))
      } else {
        # Reject unsupported inputs before touching the native library, which
        # would otherwise receive a meaningless nrow() for e.g. vectors.
        is_dense  <- is.matrix(data)
        is_sparse <- is(data, "dgCMatrix")
        if (!is_dense && !is_sparse) {
          stop("predict: cannot predict on data of class ", sQuote(class(data)))
        }

        num_row <- nrow(data)
        npred   <- 0L
        npred   <- lgb.call("LGBM_BoosterCalcNumPredict_R", ret = npred,
          private$handle,
          as.integer(num_row),
          as.integer(rawscore),
          as.integer(predleaf),
          as.integer(num_iteration))
        # allocate space for the predictions
        preds <- rep(0.0, npred)

        if (is_dense) {
          # The native matrix path is assumed to read double storage; coerce
          # integer/logical matrices so they are not misinterpreted.
          if (storage.mode(data) != "double") {
            storage.mode(data) <- "double"
          }
          preds <- lgb.call("LGBM_BoosterPredictForMat_R", ret = preds,
            private$handle,
            data,
            as.integer(nrow(data)),
            as.integer(ncol(data)),
            as.integer(rawscore),
            as.integer(predleaf),
            as.integer(num_iteration))
        } else {
          preds <- lgb.call("LGBM_BoosterPredictForCSC_R", ret = preds,
            private$handle,
            data@p,
            data@i,
            data@x,
            length(data@p),
            length(data@x),
            nrow(data),
            as.integer(rawscore),
            as.integer(predleaf),
            as.integer(num_iteration))
        }
      }

      # With zero rows there is nothing to validate or reshape, and
      # `length(preds) %% 0` would be NaN (an error inside `if`).
      if (num_row == 0L) {
        return(preds)
      }

      if (length(preds) %% num_row != 0) {
        stop("predict: prediction length ", sQuote(length(preds))," is not a multiple of nrows(data): ", sQuote(num_row))
      }
      npred_per_case <- length(preds) / num_row

      # Predictions are laid out row by row (see the file branch above), so
      # the matrix must be filled by row to keep each case on its own row.
      if (reshape && npred_per_case > 1) {
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
      }
      preds
    }
  ),
  private = list(handle = NULL, need_free_handle = FALSE)
)