lgb.Predictor.R 5.98 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
#' @importFrom methods is
#' @importFrom R6 R6Class
Predictor <- R6::R6Class(

  classname = "lgb.Predictor",
  cloneable = FALSE,
  public = list(

    # Finalizer: releases the booster handle, but only if this object created it.
    # Handles handed in by the caller (need_free_handle == FALSE) are left alone.
    finalize = function() {

      if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {

        lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
        private$handle <- NULL

      }

    },

    # Build a predictor from a model file or from an existing booster handle.
    #
    # modelfile: either a character value, which is handed to
    #            "LGBM_BoosterCreateFromModelfile_R", or an object of class
    #            "lgb.Booster.handle", which is stored as-is
    # ...:       extra parameters; they are serialized with lgb.params2str() and
    #            passed along on every predict() call
    #
    # Stops with an error when `modelfile` is neither of the two accepted kinds.
    initialize = function(modelfile, ...) {

      params <- list(...)
      private$params <- lgb.params2str(params)

      # Placeholder value handed to lgb.call() as `ret`
      handle <- 0.0

      if (is.character(modelfile)) {

        # modelfile is a character value: create a new handle from it.
        # This object created the handle, so finalize() must free it.
        handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret = handle, lgb.c_str(modelfile))
        private$need_free_handle <- TRUE

      } else if (methods::is(modelfile, "lgb.Booster.handle")) {

        # modelfile already is a booster handle: reuse it and never free it here
        handle <- modelfile
        private$need_free_handle <- FALSE

      } else {

        stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")

      }

      # Override class and store the handle
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

    },

    # Get current iteration: returns whatever lgb.call() yields for
    # "LGBM_BoosterGetCurrentIteration_R" on the stored handle
    current_iter = function() {

      cur_iter <- 0L
      lgb.call("LGBM_BoosterGetCurrentIteration_R",  ret = cur_iter, private$handle)

    },

    # Predict from data.
    #
    # data:          a single character value (treated as the name of a data file),
    #                a base matrix, or a dgCMatrix; anything else is an error
    # num_iteration: number of iterations to use; NULL is replaced by -1
    #                (use all, per the original comment in this file)
    # rawscore, predleaf, predcontrib: flags forwarded to the backend as integers
    # header:        forwarded as an integer, only for file input
    # reshape:       when TRUE and there is more than one prediction per row,
    #                return a matrix instead of a flat vector
    #
    # Returns a numeric vector of predictions, or a matrix with one row per input
    # row when predleaf/predcontrib is set or reshaping applies.
    predict = function(data,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE) {

      # No iteration count given: fall back to -1
      if (is.null(num_iteration)) {
        num_iteration <- -1
      }

      num_row <- 0L

      # Check if data is a file name and not a matrix
      if (identical(class(data), "character") && length(data) == 1) {

        # Predictions are written to a temporary file with a "lightgbm_" pattern
        # in its name; the file is removed again when predict() exits
        tmp_filename <- tempfile(pattern = "lightgbm_")
        on.exit(unlink(tmp_filename), add = TRUE)

        # Predict from the data file, writing the results to the temporary file
        lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data,
          as.integer(header),
          as.integer(rawscore),
          as.integer(predleaf),
          as.integer(predcontrib),
          as.integer(num_iteration),
          private$params,
          lgb.c_str(tmp_filename))

        # Read the predictions back: one line per row, tab-separated values.
        # Transposing before flattening keeps each row's values adjacent.
        preds <- utils::read.delim(tmp_filename, header = FALSE, sep = "\t")
        num_row <- nrow(preds)
        preds <- as.vector(t(preds))

      } else {

        # Not a file, we need to predict from an R object
        num_row <- nrow(data)

        npred <- 0L

        # Ask the backend how many prediction values to expect
        npred <- lgb.call("LGBM_BoosterCalcNumPredict_R",
                          ret = npred,
                          private$handle,
                          as.integer(num_row),
                          as.integer(rawscore),
                          as.integer(predleaf),
                          as.integer(predcontrib),
                          as.integer(num_iteration))

        # Pre-allocate the result vector
        preds <- numeric(npred)

        if (is.matrix(data)) {

          # Dense matrix input
          preds <- lgb.call("LGBM_BoosterPredictForMat_R",
                            ret = preds,
                            private$handle,
                            data,
                            as.integer(nrow(data)),
                            as.integer(ncol(data)),
                            as.integer(rawscore),
                            as.integer(predleaf),
                            as.integer(predcontrib),
                            as.integer(num_iteration),
                            private$params)

        } else if (methods::is(data, "dgCMatrix")) {

          # dgCMatrix input (sparse matrix, column compressed format).
          # 2147483647 is the largest value an R integer can hold.
          if (length(data@p) > 2147483647) {
            stop("Cannot support large CSC matrix")
          }

          preds <- lgb.call("LGBM_BoosterPredictForCSC_R",
                            ret = preds,
                            private$handle,
                            data@p,
                            data@i,
                            data@x,
                            length(data@p),
                            length(data@x),
                            nrow(data),
                            as.integer(rawscore),
                            as.integer(predleaf),
                            as.integer(predcontrib),
                            as.integer(num_iteration),
                            private$params)

        } else {

          # Cannot predict on unknown class
          # to-do: predict from lgb.Dataset
          # class() may return several names; join them so the message stays readable
          stop("predict: cannot predict on data of class ",
               paste(sQuote(class(data)), collapse = ", "))

        }
      }

      # The number of predictions must be a multiple of the number of rows
      if (length(preds) %% num_row != 0) {
        stop("predict: prediction length ",
             sQuote(length(preds)),
             " is not a multiple of nrows(data): ",
             sQuote(num_row))
      }

      # Number of prediction values per row
      npred_per_case <- length(preds) / num_row

      # Data reshaping: mandatory for predleaf/predcontrib, otherwise only when
      # requested via `reshape` and there are several values per row
      if (predleaf || predcontrib || (reshape && npred_per_case > 1)) {
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
      }

      # Return predictions
      return(preds)

    }

  ),
  # handle:           booster handle used for all backend calls
  # need_free_handle: TRUE when initialize() created the handle itself
  # params:           parameter string produced by lgb.params2str()
  private = list(handle = NULL,
                 need_free_handle = FALSE,
                 params = "")
)