# lgb.Predictor.R
# (blame-view residue removed; last committed by Guolin Ke)
# R6 class wrapping a LightGBM booster handle for prediction.
# The handle is either loaded from a model file (owned by this object and
# freed in finalize()) or borrowed from an existing "lgb.Booster.handle"
# (never freed by this object).
Predictor <- R6Class(
  classname = "lgb.Predictor",
  cloneable = FALSE,
  public = list(

    # Finalize will free up the handle, but only if this object owns it
    finalize = function() {

      # Borrowed handles (need_free_handle == FALSE) are left untouched
      if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {

        # Freeing up handle
        lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
        private$handle <- NULL

      }

    },

    # Initialize will create a starter model
    #
    # modelfile: either a character path to a saved model file, or an
    #            existing "lgb.Booster.handle" object
    # ...:       extra parameters; serialized once with lgb.params2str() and
    #            passed to every prediction call
    initialize = function(modelfile, ...) {

      # Serialize extra parameters once so predict() can reuse them
      params <- list(...)
      private$params <- lgb.params2str(params)

      if (is.character(modelfile)) {

        # Load the model into a freshly created handle that this object owns.
        # The handle is only allocated here, because the other branch reuses
        # the caller's handle and would discard a new one.
        handle <- lgb.new.handle()
        handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R",
                           ret = handle,
                           lgb.c_str(modelfile))
        private$need_free_handle <- TRUE

      } else if (is(modelfile, "lgb.Booster.handle")) {

        # Reuse an existing booster handle; this object never frees it
        handle <- modelfile
        private$need_free_handle <- FALSE

      } else {

        # Model file is unknown
        stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")

      }

      # Override class and store it
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle

    },

    # Get current iteration
    # Returns the result of LGBM_BoosterGetCurrentIteration_R called with an
    # integer return buffer (presumably the number of boosting iterations
    # in the model - confirm against the C API)
    current_iter = function() {

      cur_iter <- 0L
      lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)

    },

    # Predict from data
    #
    # data:          a character path to a data file, a dense matrix, or a
    #                dgCMatrix (column-compressed sparse matrix)
    # num_iteration: number of iterations to use; NULL is sent as -1 (use all)
    # rawscore:      forwarded to LightGBM as an integer flag
    #                (presumably "return raw, untransformed scores")
    # predleaf:      forwarded as an integer flag; when TRUE the result is
    #                always reshaped into a matrix with one row per case
    # header:        file input only; forwarded as an integer flag
    #                (presumably "data file has a header row")
    # reshape:       when TRUE and there are several predictions per case,
    #                return a matrix with one row per case instead of a vector
    #
    # Returns a numeric vector, or a matrix when reshaping applies.
    predict = function(data,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       header = FALSE,
                       reshape = FALSE) {

      # Check if number of iterations is existing - if not, then set it to -1 (use all)
      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }

      # Set temporary variable
      num_row <- 0L

      if (is.character(data)) {

        # Data is a filename. LightGBM writes its predictions to a temporary
        # "lightgbm_*" file, which is deleted when this method exits.
        tmp_filename <- tempfile(pattern = "lightgbm_")
        on.exit(unlink(tmp_filename), add = TRUE)

        lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data,
                 as.integer(header),
                 as.integer(rawscore),
                 as.integer(predleaf),
                 as.integer(num_iteration),
                 private$params,
                 lgb.c_str(tmp_filename))

        # Read the tab-separated predictions back (one line per case).
        # The argument is `sep`; the former `seq` was forwarded through
        # read.delim()'s `...` to read.table(), which rejects it.
        preds <- read.delim(tmp_filename, header = FALSE, sep = "\t")
        num_row <- nrow(preds)
        preds <- as.vector(t(preds))

      } else {

        # Validate the input type before calling into LightGBM: for
        # unsupported objects nrow() may be NULL and would otherwise be passed
        # to native code as integer(0).
        if (!is.matrix(data) && !is(data, "dgCMatrix")) {
          # to-do: predict from lgb.Dataset
          stop("predict: cannot predict on data of class ", sQuote(class(data)))
        }

        # Not a file, we need to predict from R object
        num_row <- nrow(data)

        # Ask LightGBM how many values the prediction will produce so the
        # output buffer can be pre-allocated
        npred <- 0L
        npred <- lgb.call("LGBM_BoosterCalcNumPredict_R",
                          ret = npred,
                          private$handle,
                          as.integer(num_row),
                          as.integer(rawscore),
                          as.integer(predleaf),
                          as.integer(num_iteration))

        # Pre-allocate empty vector
        preds <- numeric(npred)

        if (is.matrix(data)) {

          # Dense matrix
          preds <- lgb.call("LGBM_BoosterPredictForMat_R",
                            ret = preds,
                            private$handle,
                            data,
                            as.integer(nrow(data)),
                            as.integer(ncol(data)),
                            as.integer(rawscore),
                            as.integer(predleaf),
                            as.integer(num_iteration),
                            private$params)

        } else {

          # dgCMatrix (column-compressed sparse matrix): pass its raw slots
          # (@p, @i, @x) plus their lengths and the row count
          preds <- lgb.call("LGBM_BoosterPredictForCSC_R",
                            ret = preds,
                            private$handle,
                            data@p,
                            data@i,
                            data@x,
                            length(data@p),
                            length(data@x),
                            nrow(data),
                            as.integer(rawscore),
                            as.integer(predleaf),
                            as.integer(num_iteration),
                            private$params)

        }
      }

      # Check if number of rows is strange (not a multiple of the dataset rows)
      if (length(preds) %% num_row != 0) {
        stop("predict: prediction length ", sQuote(length(preds))," is not a multiple of nrows(data): ", sQuote(num_row))
      }

      # Get number of cases per row
      npred_per_case <- length(preds) / num_row

      # Data reshaping: predictions are laid out case by case, so fill the
      # matrix row-wise (byrow = TRUE) to get one row per case
      if (predleaf) {

        # Predict leaves only, reshaping is mandatory
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)

      } else if (reshape && npred_per_case > 1) {

        # Predict with data reshaping
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)

      }

      # Return predictions
      return(preds)

    }

  ),
  private = list(handle = NULL,
                 need_free_handle = FALSE,
                 params = "")
)