lgb.Predictor.R 6.13 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
#' @importFrom methods is
James Lamb's avatar
James Lamb committed
2
#' @importFrom R6 R6Class
3
#' @importFrom utils read.delim
James Lamb's avatar
James Lamb committed
4
5
# Internal R6 wrapper around a native booster handle, used to produce
# predictions either from a data file on disk or from an in-memory matrix.
#
# Private state:
#   handle           - booster handle passed to every native routine
#   need_free_handle - TRUE only when this object created the handle itself
#                      (and therefore has to release it in finalize())
#   params           - parameters as produced by lgb.params2str()
Predictor <- R6::R6Class(

  classname = "lgb.Predictor",
  cloneable = FALSE,
  public = list(

    # Finalize will free up the handles
    #
    # Only a handle created by this object (need_free_handle is TRUE) is
    # released; a handle that was passed in by the caller is left untouched.
    # Always returns NULL invisibly.
    finalize = function() {

      # Check the need for freeing handle
      if (private$need_free_handle) {

        .Call(
          LGBM_BoosterFree_R
          , private$handle
        )
        private$handle <- NULL

        # The handle is gone, so drop ownership as well: without this a second
        # finalize() call would pass the NULL handle to the native free routine
        private$need_free_handle <- FALSE

      }

      return(invisible(NULL))

    },

    # Initialize will create a starter model
    #
    # modelfile: either a character path to a saved model (a new handle is
    #            created from it and owned by this object) or an existing
    #            "lgb.Booster.handle" (borrowed, never freed here)
    # params:    list of parameters, stored after conversion with
    #            lgb.params2str()
    #
    # Stops with an error for any other type of modelfile.
    initialize = function(modelfile, params = list()) {
      private$params <- lgb.params2str(params = params)
      handle <- NULL

      if (is.character(modelfile)) {

        # Create handle on it (path.expand() resolves a leading "~")
        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , path.expand(modelfile)
        )
        private$need_free_handle <- TRUE

      } else if (methods::is(modelfile, "lgb.Booster.handle")) {

        # Model file is a booster handle already: reuse it without taking
        # ownership, so finalize() will not free it
        handle <- modelfile
        private$need_free_handle <- FALSE

      } else {

        stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")

      }

      # Override class and store it
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

      return(invisible(NULL))

    },

    # Get current iteration
    #
    # NOTE(review): the result of .Call() is discarded and cur_iter is returned
    # afterwards, so this presumably relies on the native routine writing the
    # value into cur_iter in place - confirm against the C side before
    # restructuring.
    current_iter = function() {

      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Predict from data
    #
    # data:            a single character string (treated as the path of a data
    #                  file), a base matrix, or a "dgCMatrix" sparse matrix
    # start_iteration: NULL is replaced by 0 (start from the first iteration)
    # num_iteration:   NULL is replaced by -1 (use all)
    # rawscore, predleaf, predcontrib:
    #                  flags forwarded to the native routines as integers
    # header:          flag forwarded (as integer) to the file-based routine only
    # reshape:         if TRUE and there is more than one prediction per row,
    #                  return a matrix with one row per input row
    #
    # Returns a numeric vector of predictions, or a matrix (one row per input
    # row) when predleaf or predcontrib is set, or when reshape applies.
    # Stops with an error for unsupported data classes, for oversized CSC
    # matrices, and when the number of predictions is not a multiple of the
    # number of rows.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE) {

      # Check if number of iterations is existing - if not, then set it to -1 (use all)
      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }
      # Check if start iterations is existing - if not, then set it to 0 (start from the first iteration)
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      num_row <- 0L

      # Check if data is a file name and not a matrix
      if (identical(class(data), "character") && length(data) == 1L) {

        data <- path.expand(data)

        # Data is a filename, create a temporary file with a "lightgbm_" pattern in it.
        # The temporary file is removed when predict() exits, even on error
        tmp_filename <- tempfile(pattern = "lightgbm_")
        on.exit(unlink(tmp_filename), add = TRUE)

        # Predict from the data file; tmp_filename is handed to the native
        # routine and the predictions are read back from it below
        .Call(
          LGBM_BoosterPredictForFile_R
          , private$handle
          , data
          , as.integer(header)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
          , private$params
          , tmp_filename
        )

        # Get predictions from file (tab-separated, one line per input row)
        preds <- utils::read.delim(tmp_filename, header = FALSE, sep = "\t")
        num_row <- nrow(preds)

        # Flatten row by row so the layout matches the in-memory branch
        preds <- as.vector(t(preds))

      } else {

        # Not a file, we need to predict from R object.
        # Reject unsupported inputs BEFORE calling into native code: for an
        # object without dimensions nrow() is NULL, and as.integer(NULL) would
        # hand a zero-length vector to the native routine
        data_is_dense <- is.matrix(data)
        data_is_csc <- methods::is(data, "dgCMatrix")
        if (!data_is_dense && !data_is_csc) {
          stop(
            "predict: cannot predict on data of class "
            , paste(sQuote(class(data)), collapse = ", ")
          )
        }

        num_row <- nrow(data)

        npred <- 0L

        # Check number of predictions to do.
        # NOTE(review): npred is used right after this call without being
        # assigned, so the native routine presumably fills it in place - confirm
        .Call(
          LGBM_BoosterCalcNumPredict_R
          , private$handle
          , as.integer(num_row)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
          , npred
        )

        # Pre-allocate empty vector
        preds <- numeric(npred)

        if (data_is_dense) {

          # this if() prevents the memory and computational costs
          # of converting something that is already "double" to "double"
          if (storage.mode(data) != "double") {
            storage.mode(data) <- "double"
          }
          .Call(
            LGBM_BoosterPredictForMat_R
            , private$handle
            , data
            , as.integer(nrow(data))
            , as.integer(ncol(data))
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
            , preds
          )

        } else {

          # data is a dgCMatrix (sparse matrix, column compressed format)
          if (length(data@p) > 2147483647L) {
            stop("Cannot support large CSC matrix")
          }
          .Call(
            LGBM_BoosterPredictForCSC_R
            , private$handle
            , data@p
            , data@i
            , data@x
            , length(data@p)
            , length(data@x)
            , nrow(data)
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
            , preds
          )

        }
      }

      # Check if number of rows is strange (not a multiple of the dataset rows)
      if (length(preds) %% num_row != 0L) {
        stop(
          "predict: prediction length "
          , sQuote(length(preds))
          , " is not a multiple of nrows(data): "
          , sQuote(num_row)
        )
      }

      # Get number of cases per row
      npred_per_case <- length(preds) / num_row

      # Data reshaping
      # (scalar flags, so the short-circuit || is the right operator here)
      if (predleaf || predcontrib) {

        # Leaf indices / contributions: reshaping is mandatory
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)

      } else if (reshape && npred_per_case > 1L) {

        # Predict with data reshaping
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)

      }

      return(preds)

    }

  ),
  private = list(
    handle = NULL
    , need_free_handle = FALSE
    , params = ""
  )
)