#' @importFrom methods is
#' @importFrom R6 R6Class
#' @importFrom utils read.delim
Predictor <- R6::R6Class(

  # Internal wrapper around a native booster handle. It can score data given
  # as a file path, a dense matrix or a dgCMatrix (see predict() below).
  classname = "lgb.Predictor",
  cloneable = FALSE,
  public = list(

    # Free the native handle, but only when this object created it.
    # Safe to call more than once.
    finalize = function() {

      if (private$need_free_handle) {

        .Call(
          LGBM_BoosterFree_R
          , private$handle
        )
        private$handle <- NULL

        # Clear the ownership flag so that a repeated call (e.g. an explicit
        # call followed by the garbage-collection finalizer) does not hand a
        # NULL handle to the native free routine.
        private$need_free_handle <- FALSE

      }

      return(invisible(NULL))

    },

    # Create a predictor.
    #
    # modelfile: one of
    #   * a character path to a model file: a new handle is created from it
    #     and this object becomes responsible for freeing it
    #   * an "lgb.Booster.handle" / external pointer: used as-is, not freed here
    #   * an lgb.Booster: its handle is fetched with get_handle(), not freed here
    # params: list of parameters; converted once with lgb.params2str() and
    #   forwarded to every native predict call
    #
    # Stops with an error for any other type of `modelfile`.
    initialize = function(modelfile, params = list()) {
      private$params <- lgb.params2str(params = params)
      handle <- NULL

      if (is.character(modelfile)) {

        # Create a handle from the model file; we own it
        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , path.expand(modelfile)
        )
        private$need_free_handle <- TRUE

      } else if (methods::is(modelfile, "lgb.Booster.handle") || inherits(modelfile, "externalptr")) {

        # Already a booster handle; borrowed, so never freed by this object
        handle <- modelfile
        private$need_free_handle <- FALSE

      } else if (lgb.is.Booster(modelfile)) {

        # Borrow the handle held by the Booster object
        handle <- modelfile$get_handle()
        private$need_free_handle <- FALSE

      } else {

        stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")

      }

      # Override class and store it
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      return(invisible(NULL))

    },

    # Get current iteration
    current_iter = function() {

      # The native routine receives `cur_iter` as its output argument
      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Predict from data.
    #
    # data: a single character string (path to a data file), a matrix,
    #   or a dgCMatrix
    # start_iteration: NULL is replaced by 0L
    # num_iteration: NULL is replaced by -1L
    # rawscore, predleaf, predcontrib: flags forwarded to the native predict
    #   routines as integers
    # header: forwarded (as integer) to the file-based predict routine only
    #
    # Returns the predictions as a vector, or as a matrix with one row per
    # input row when there is more than one prediction per row or when
    # predleaf / predcontrib is set.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE) {

      # Check if number of iterations is existing - if not, then set it to -1 (use all)
      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }
      # Check if start iterations is existing - if not, then set it to 0 (start from the first iteration)
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      num_row <- 0L

      # Check if data is a file name and not a matrix
      if (identical(class(data), "character") && length(data) == 1L) {

        data <- path.expand(data)

        # Predictions are written to a temporary file with a "lightgbm_"
        # pattern in its name; it is removed when this function exits
        tmp_filename <- tempfile(pattern = "lightgbm_")
        on.exit(unlink(tmp_filename), add = TRUE)

        # Predict from the data file into the temporary file
        .Call(
          LGBM_BoosterPredictForFile_R
          , private$handle
          , data
          , as.integer(header)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
          , private$params
          , tmp_filename
        )

        # Get predictions from file
        preds <- utils::read.delim(tmp_filename, header = FALSE, sep = "\t")
        num_row <- nrow(preds)
        # Flatten row by row, matching the byrow = TRUE reshape below
        preds <- as.vector(t(preds))

      } else {

        # Not a file, we need to predict from an R object.
        # Validate the input BEFORE any native call: an unsupported object
        # (e.g. a plain vector, for which nrow() is NULL) must not reach
        # native code with a zero-length row count.
        is_dense <- is.matrix(data)
        is_sparse <- !is_dense && methods::is(data, "dgCMatrix")

        if (!is_dense && !is_sparse) {
          # toString() keeps multi-class objects readable; for a single
          # class the message is the same as sQuote(class(data))
          stop("predict: cannot predict on data of class ", toString(sQuote(class(data))))
        }

        if (is_sparse && length(data@p) > 2147483647L) {
          stop("Cannot support large CSC matrix")
        }

        num_row <- nrow(data)

        # Ask the native library how many prediction values will be produced
        # (`npred` is the output argument of the call)
        npred <- 0L
        .Call(
          LGBM_BoosterCalcNumPredict_R
          , private$handle
          , as.integer(num_row)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
          , npred
        )

        # Pre-allocate the output vector passed to the native routines
        preds <- numeric(npred)

        if (is_dense) {

          # this if() prevents the memory and computational costs
          # of converting something that is already "double" to "double"
          if (storage.mode(data) != "double") {
            storage.mode(data) <- "double"
          }
          .Call(
            LGBM_BoosterPredictForMat_R
            , private$handle
            , data
            , as.integer(nrow(data))
            , as.integer(ncol(data))
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
            , preds
          )

        } else {

          # dgCMatrix (sparse matrix, column compressed format)
          .Call(
            LGBM_BoosterPredictForCSC_R
            , private$handle
            , data@p
            , data@i
            , data@x
            , length(data@p)
            , length(data@x)
            , nrow(data)
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
            , preds
          )

        }
      }

      # Check if number of rows is strange (not a multiple of the dataset rows)
      if (length(preds) %% num_row != 0L) {
        stop(
          "predict: prediction length "
          , sQuote(length(preds))
          , " is not a multiple of nrows(data): "
          , sQuote(num_row)
        )
      }

      # Get number of predictions per row
      npred_per_case <- length(preds) / num_row

      # Data reshaping: one row per input row
      if (npred_per_case > 1L || predleaf || predcontrib) {
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
      }

      return(preds)

    }

  ),
  private = list(
    # native booster handle (classed "lgb.Booster.handle" by initialize())
    handle = NULL
    # TRUE only when initialize() created the handle from a model file
    , need_free_handle = FALSE
    # parameters as produced by lgb.params2str()
    , params = ""
  )
)