#' @importFrom methods is
#' @importFrom R6 R6Class
Predictor <- R6::R6Class(

  # R6 wrapper around a native booster handle, used to produce predictions.
  # The handle is either created from a model file (and then owned and freed
  # by this object) or borrowed from an existing "lgb.Booster.handle" (and
  # then never freed here).
  classname = "lgb.Predictor",
  cloneable = FALSE,
  public = list(

    # Free the native handle, but only when this object created it itself
    # (need_free_handle is TRUE) and the handle is not already null.
    finalize = function() {

      if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {

        lgb.call(
          "LGBM_BoosterFree_R"
          , ret = NULL
          , private$handle
        )
        private$handle <- NULL

      }

    },

    # Build a predictor.
    #
    # modelfile: either a character filename of a saved model, or an existing
    #            object of class "lgb.Booster.handle".
    # ...:       extra parameters; they are converted once with
    #            lgb.params2str() and forwarded to the native predict calls.
    #
    # Stops with an error when modelfile is neither of the accepted types.
    initialize = function(modelfile, ...) {
      params <- list(...)
      private$params <- lgb.params2str(params)

      # Placeholder that the native call fills in with the new handle
      handle <- 0.0

      if (is.character(modelfile)) {

        # A filename was given: load the model, this object owns the handle
        handle <- lgb.call(
          "LGBM_BoosterCreateFromModelfile_R"
          , ret = handle
          , lgb.c_str(modelfile)
        )
        private$need_free_handle <- TRUE

      } else if (methods::is(modelfile, "lgb.Booster.handle")) {

        # An existing handle was given: borrow it, never free it from here
        handle <- modelfile
        private$need_free_handle <- FALSE

      } else {

        stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")

      }

      # Override class and store it
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

    },

    # Return whatever "LGBM_BoosterGetCurrentIteration_R" reports for the
    # stored handle (seeded with an integer placeholder).
    current_iter = function() {

      cur_iter <- 0L
      lgb.call(
        "LGBM_BoosterGetCurrentIteration_R"
        , ret = cur_iter
        , private$handle
      )

    },

    # Predict from data.
    #
    # data:          a single character string (path of a data file), a base
    #                matrix, or a dgCMatrix. Anything else is an error.
    # num_iteration: number of iterations to use; NULL is sent as -1.
    # rawscore, predleaf, predcontrib: flags forwarded (as integers) to the
    #                native predict functions.
    # header:        forwarded (as an integer) only when data is a file path.
    # reshape:       when TRUE and there is more than one prediction per row,
    #                return a matrix instead of a flat vector.
    #
    # Returns a vector of predictions, or a matrix with one row per input row
    # when predleaf or predcontrib is set, or when reshape applies.
    predict = function(data,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE) {

      # No iteration count given: send -1 to the native code
      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }

      num_row <- 0L

      # Check if data is a file name and not a matrix
      if (identical(class(data), "character") && length(data) == 1) {

        # The native code writes predictions into a temporary file, which is
        # removed again when this function exits
        tmp_filename <- tempfile(pattern = "lightgbm_")
        on.exit(unlink(tmp_filename), add = TRUE)

        lgb.call(
          "LGBM_BoosterPredictForFile_R"
          , ret = NULL
          , private$handle
          , data
          , as.integer(header)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(num_iteration)
          , private$params
          , lgb.c_str(tmp_filename)
        )

        # Read predictions back; flatten row by row into a single vector
        preds <- read.delim(tmp_filename, header = FALSE, sep = "\t")
        num_row <- nrow(preds)
        preds <- as.vector(t(preds))

      } else {

        # Not a file, we need to predict from an R object. Validate the type
        # BEFORE any native call so unsupported input (for which nrow() may be
        # NULL) is never handed to native code.
        is_dense <- is.matrix(data)
        is_csc <- !is_dense && methods::is(data, "dgCMatrix")

        if (!is_dense && !is_csc) {
          # to-do: predict from lgb.Dataset
          stop(
            "predict: cannot predict on data of class "
            , paste(sQuote(class(data)), collapse = ", ")
          )
        }

        # The column-pointer length is passed to native code, so it must fit
        # in a 32-bit signed integer
        if (is_csc && length(data@p) > 2147483647) {
          stop("Cannot support large CSC matrix")
        }

        num_row <- nrow(data)

        # Ask the native code how many prediction values it will produce
        npred <- 0L
        npred <- lgb.call(
          "LGBM_BoosterCalcNumPredict_R"
          , ret = npred
          , private$handle
          , as.integer(num_row)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(num_iteration)
        )

        # Pre-allocate the output vector that the native call fills in
        preds <- numeric(npred)

        if (is_dense) {

          # Dense base-R matrix
          preds <- lgb.call(
            "LGBM_BoosterPredictForMat_R"
            , ret = preds
            , private$handle
            , data
            , as.integer(nrow(data))
            , as.integer(ncol(data))
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(num_iteration)
            , private$params
          )

        } else {

          # dgCMatrix (sparse matrix, column compressed format)
          preds <- lgb.call(
            "LGBM_BoosterPredictForCSC_R"
            , ret = preds
            , private$handle
            , data@p
            , data@i
            , data@x
            , length(data@p)
            , length(data@x)
            , nrow(data)
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(num_iteration)
            , private$params
          )

        }
      }

      # The number of predictions must be a whole multiple of the row count
      if (length(preds) %% num_row != 0) {
        stop(
          "predict: prediction length "
          , sQuote(length(preds))
          ," is not a multiple of nrows(data): "
          , sQuote(num_row)
        )
      }

      # Number of prediction values per input row
      npred_per_case <- length(preds) / num_row

      # Reshaping is mandatory for leaf/contribution predictions, and is done
      # on request when each row has more than one prediction
      if (predleaf || predcontrib || (reshape && npred_per_case > 1)) {
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
      }

      # Return predictions
      return(preds)

    }

  ),
  private = list(
    # Native booster handle (class "lgb.Booster.handle" once initialized)
    handle = NULL
    # TRUE only when this object created the handle and must free it
    , need_free_handle = FALSE
    # Parameter string built by lgb.params2str() in initialize()
    , params = ""
  )
)