#' @importFrom methods is
#' @importFrom R6 R6Class
#' @importFrom utils read.delim
Predictor <- R6::R6Class(

  # Internal R6 class that wraps a LightGBM booster handle for prediction.
  #
  # A Predictor is built either from a model file name (the handle is then
  # created here and freed in finalize()) or from an existing
  # "lgb.Booster.handle" (the handle is only borrowed and never freed here).
  classname = "lgb.Predictor",
  cloneable = FALSE,
  public = list(

    # Free the underlying booster handle.
    # Only acts when this object created the handle itself
    # (need_free_handle is TRUE) and the handle is not already null.
    finalize = function() {

      if (private$need_free_handle && !lgb.is.null.handle(x = private$handle)) {

        lgb.call(
          fun_name = "LGBM_BoosterFree_R"
          , ret = NULL
          , private$handle
        )
        private$handle <- NULL

      }

      return(invisible(NULL))

    },

    # Create a predictor.
    #
    # modelfile: either a character file name of a saved model, or an
    #            object of class "lgb.Booster.handle"
    # ...:       extra parameters; they are serialized once with
    #            lgb.params2str() and forwarded to every prediction call
    #
    # Raises an error when modelfile is neither of the two supported kinds.
    initialize = function(modelfile, ...) {
      params <- list(...)
      private$params <- lgb.params2str(params = params)

      # Start from a null handle; it is filled in below
      handle <- lgb.null.handle()

      if (is.character(modelfile)) {

        # Load the model from disk; this object now owns the new handle
        handle <- lgb.call(
          fun_name = "LGBM_BoosterCreateFromModelfile_R"
          , ret = handle
          , lgb.c_str(x = modelfile)
        )
        private$need_free_handle <- TRUE

      } else if (methods::is(modelfile, "lgb.Booster.handle")) {

        # Re-use a handle created elsewhere; finalize() must not free it
        handle <- modelfile
        private$need_free_handle <- FALSE

      } else {

        stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")

      }

      # Override class and store it
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      return(invisible(NULL))

    },

    # Return the value reported by LGBM_BoosterGetCurrentIteration_R
    # for the wrapped handle.
    current_iter = function() {

      cur_iter <- 0L
      return(
        lgb.call(
          fun_name = "LGBM_BoosterGetCurrentIteration_R"
          , ret = cur_iter
          , private$handle
        )
      )

    },

    # Predict from data.
    #
    # data:            a single character string (treated as the name of a
    #                  data file), a matrix, or a dgCMatrix
    # start_iteration: first iteration to use; NULL is replaced by 0L
    # num_iteration:   number of iterations to use; NULL is replaced by -1L
    #                  (use all)
    # rawscore, predleaf, predcontrib:
    #                  logical flags forwarded as integers to the native
    #                  prediction routines; when predleaf or predcontrib is
    #                  TRUE the result is always returned as a matrix
    # header:          forwarded (as integer) to the file-based prediction
    #                  routine only
    # reshape:         when TRUE and there is more than one prediction per
    #                  row, return a matrix with one row per data row
    #
    # Returns a vector of predictions, or a matrix filled row by row with
    # one row per data row (see predleaf / predcontrib / reshape).
    # Raises an error for unsupported data classes, for dgCMatrix inputs
    # whose column pointer is too long, and when the number of predictions
    # is not a multiple of the number of rows.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE) {

      # Check if number of iterations is existing - if not, then set it to -1 (use all)
      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }

      # Check if start iterations is existing - if not, then set it to 0 (start from the first iteration)
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      # A length-1 character value is interpreted as a file name
      data_is_file <- is.character(data) && length(data) == 1L

      # Reject unsupported inputs before touching the native library, so that
      # nrow() of a non-matrix object is never handed to a native routine
      if (!data_is_file && !is.matrix(data) && !methods::is(data, "dgCMatrix")) {
        stop(
          "predict: cannot predict on data of class "
          , paste(sQuote(class(data)), collapse = ", ")
        )
      }

      num_row <- 0L

      if (data_is_file) {

        # Predictions are written to a temporary file with a "lightgbm_"
        # pattern in its name, which is removed when this method exits
        tmp_filename <- tempfile(pattern = "lightgbm_")
        on.exit(unlink(tmp_filename), add = TRUE)

        lgb.call(
          fun_name = "LGBM_BoosterPredictForFile_R"
          , ret = NULL
          , private$handle
          , data
          , as.integer(header)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
          , private$params
          , lgb.c_str(x = tmp_filename)
        )

        # Read the tab-separated predictions back; one line per data row.
        # Transposing first makes the flattened vector row-major.
        preds <- utils::read.delim(tmp_filename, header = FALSE, sep = "\t")
        num_row <- nrow(preds)
        preds <- as.vector(t(preds))

      } else {

        # Not a file, we need to predict from an R object
        num_row <- nrow(data)

        # Ask the library how many values the prediction will produce
        npred <- 0L
        npred <- lgb.call(
          fun_name = "LGBM_BoosterCalcNumPredict_R"
          , ret = npred
          , private$handle
          , as.integer(num_row)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
        )

        # Pre-allocate the output vector
        preds <- numeric(npred)

        if (is.matrix(data)) {

          # this if() prevents the memory and computational costs
          # of converting something that is already "double" to "double"
          if (storage.mode(data) != "double") {
            storage.mode(data) <- "double"
          }
          preds <- lgb.call(
            fun_name = "LGBM_BoosterPredictForMat_R"
            , ret = preds
            , private$handle
            , data
            , as.integer(nrow(data))
            , as.integer(ncol(data))
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
          )

        } else {

          # Data is a dgCMatrix (sparse matrix, column compressed format);
          # every other class was rejected by the check above
          if (length(data@p) > 2147483647L) {
            stop("Cannot support large CSC matrix")
          }
          preds <- lgb.call(
            fun_name = "LGBM_BoosterPredictForCSC_R"
            , ret = preds
            , private$handle
            , data@p
            , data@i
            , data@x
            , length(data@p)
            , length(data@x)
            , nrow(data)
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
          )

        }
      }

      # Check if number of rows is strange (not a multiple of the dataset rows)
      if (length(preds) %% num_row != 0L) {
        stop(
          "predict: prediction length "
          , sQuote(length(preds))
          , " is not a multiple of nrows(data): "
          , sQuote(num_row)
        )
      }

      # Number of predicted values per data row
      npred_per_case <- length(preds) / num_row

      # Reshaping is mandatory for leaf / contribution predictions and
      # optional (reshape = TRUE) when there are several values per row.
      # `||` is used because every operand is a single logical value.
      if (predleaf || predcontrib || (reshape && npred_per_case > 1L)) {
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
      }

      return(preds)

    }

  ),
  private = list(
    # Booster handle (class "lgb.Booster.handle") once initialized
    handle = NULL
    # TRUE only when the handle was created by initialize() from a file
    , need_free_handle = FALSE
    # Prediction parameters, already serialized by lgb.params2str()
    , params = ""
  )
)