lgb.Predictor.R 6.61 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
#' @importFrom methods is
James Lamb's avatar
James Lamb committed
2
#' @importFrom R6 R6Class
3
#' @importFrom utils read.delim
James Lamb's avatar
James Lamb committed
4
5
Predictor <- R6::R6Class(

  # Internal R6 wrapper around a native LightGBM booster handle, used to
  # generate predictions. Objects cannot be cloned (cloneable = FALSE), since
  # a clone would share the same native handle.
  classname = "lgb.Predictor",
  cloneable = FALSE,
  public = list(

    # Finalizer: release the native booster handle, but only when this object
    # created the handle itself (need_free_handle is TRUE) and the handle is
    # not already null. Handles passed in from an existing
    # "lgb.Booster.handle" are left for their owner to free.
    finalize = function() {

      if (private$need_free_handle && !lgb.is.null.handle(x = private$handle)) {

        call_state <- 0L
        .Call(
          LGBM_BoosterFree_R
          , private$handle
          , call_state
        )
        private$handle <- NULL

      }

      return(invisible(NULL))

    },

    # Create a predictor.
    #
    # modelfile: either a character path to a saved model file, from which a
    #            new native handle is created (and later freed by finalize()),
    #            or an existing "lgb.Booster.handle", which is reused and is
    #            NOT freed by this object.
    # ...:       prediction parameters; converted once with lgb.params2str()
    #            and passed to every subsequent native prediction call.
    initialize = function(modelfile, ...) {

      params <- list(...)
      private$params <- lgb.params2str(params = params)

      # Start from an empty handle; the native call below is expected to
      # fill it in place.
      handle <- lgb.null.handle()

      if (is.character(modelfile)) {

        call_state <- 0L
        .Call(
          LGBM_BoosterCreateFromModelfile_R
          , lgb.c_str(x = modelfile)
          , handle
          , call_state
        )
        private$need_free_handle <- TRUE

      } else if (methods::is(modelfile, "lgb.Booster.handle")) {

        # Reuse an existing booster handle; its owner is responsible for it
        handle <- modelfile
        private$need_free_handle <- FALSE

      } else {

        stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")

      }

      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      return(invisible(NULL))

    },

    # Return the booster's current iteration. The value returned is whatever
    # the native routine writes into `cur_iter` in place.
    current_iter = function() {

      cur_iter <- 0L
      call_state <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
        , call_state
      )
      return(cur_iter)

    },

    # Predict from data.
    #
    # data:            a single character filename (prediction is done
    #                  natively from the file), a dense matrix (coerced to
    #                  double), or a "dgCMatrix" sparse matrix.
    # start_iteration: first iteration to use; NULL means 0 (the first one).
    # num_iteration:   number of iterations to use; NULL means -1 (all).
    # rawscore, predleaf, predcontrib: select the kind of prediction.
    # header:          whether the data file has a header (file input only).
    # reshape:         if TRUE and there is more than one prediction per row,
    #                  return a matrix with one row per input row.
    #
    # Returns a numeric vector, or a matrix with one row per input row when
    # predleaf / predcontrib is TRUE or when reshaping applies.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE) {

      # NULL means "use all iterations" (-1) / "start from the first" (0)
      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      if (identical(class(data), "character") && length(data) == 1L) {

        # Data is a filename: native code writes predictions to a temporary
        # file (removed when this method exits), which is then read back.
        tmp_filename <- tempfile(pattern = "lightgbm_")
        on.exit(unlink(tmp_filename), add = TRUE)

        call_state <- 0L
        .Call(
          LGBM_BoosterPredictForFile_R
          , private$handle
          , data
          , as.integer(header)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
          , private$params
          , lgb.c_str(x = tmp_filename)
          , call_state
        )

        preds <- utils::read.delim(tmp_filename, header = FALSE, sep = "\t")
        num_row <- nrow(preds)
        # Flatten row by row so predictions for each input row stay together
        preds <- as.vector(t(preds))

      } else {

        # In-memory data: reject unsupported inputs BEFORE calling native
        # code. Otherwise e.g. a plain vector (for which nrow() is NULL) would
        # be passed on to LGBM_BoosterCalcNumPredict_R.
        is_dense <- is.matrix(data)
        if (!is_dense && !methods::is(data, "dgCMatrix")) {
          stop(
            "predict: cannot predict on data of class "
            , sQuote(paste(class(data), collapse = ", "))
          )
        }
        if (!is_dense && length(data@p) > 2147483647L) {
          stop("Cannot support large CSC matrix")
        }

        num_row <- nrow(data)

        # With zero rows the "multiple of nrows" check below would compute
        # `x %% 0` (NaN) and fail with an uninformative error; fail clearly.
        if (num_row == 0L) {
          stop("predict: cannot predict on data with 0 rows")
        }

        # The native call is expected to write the total number of
        # predictions into `npred` in place.
        npred <- 0L
        call_state <- 0L
        .Call(
          LGBM_BoosterCalcNumPredict_R
          , private$handle
          , as.integer(num_row)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
          , npred
          , call_state
        )

        # Pre-allocate the output buffer that the prediction call fills
        preds <- numeric(npred)

        if (is_dense) {

          # Skip the conversion (and its copy) when data is already "double"
          if (storage.mode(data) != "double") {
            storage.mode(data) <- "double"
          }

          call_state <- 0L
          .Call(
            LGBM_BoosterPredictForMat_R
            , private$handle
            , data
            , as.integer(nrow(data))
            , as.integer(ncol(data))
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
            , preds
            , call_state
          )

        } else {

          # dgCMatrix: sparse matrix in compressed sparse column format
          call_state <- 0L
          .Call(
            LGBM_BoosterPredictForCSC_R
            , private$handle
            , data@p
            , data@i
            , data@x
            , length(data@p)
            , length(data@x)
            , nrow(data)
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
            , preds
            , call_state
          )

        }
      }

      # Every input row must have the same number of predictions
      if (length(preds) %% num_row != 0L) {
        stop(
          "predict: prediction length "
          , sQuote(length(preds))
          , " is not a multiple of nrows(data): "
          , sQuote(num_row)
        )
      }

      npred_per_case <- length(preds) / num_row

      # Leaf indices and feature contributions are always returned as a
      # matrix; other predictions only when reshape is requested and there
      # is more than one value per row.
      if (predleaf || predcontrib || (reshape && npred_per_case > 1L)) {
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
      }

      return(preds)

    }

  ),

  private = list(
    handle = NULL
    , need_free_handle = FALSE
    , params = ""
  )
)