lightgbm.R 13 KB
#' @name lgb_shared_params
#' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
#'             may allow you to pass other types of data like \code{matrix} and then separately supply
#'             \code{label} as a keyword argument.
#' @param early_stopping_rounds int. Activates early stopping. When this parameter is non-null,
#'                              training will stop if the evaluation of any metric on any validation set
#'                              fails to improve for \code{early_stopping_rounds} consecutive boosting rounds.
#'                              If training stops early, the returned model will have attribute \code{best_iter}
#'                              set to the iteration number of the best iteration.
#' @param eval evaluation function(s). This can be a character vector, function, or list with a mixture of
#'             strings and functions.
#'
#'             \itemize{
#'                 \item{\bold{a. character vector}:
#'                     If you provide a character vector to this argument, it should contain strings with valid
#'                     evaluation metrics.
#'                     See \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric}{
#'                     The "metric" section of the documentation}
#'                     for a list of valid metrics.
#'                 }
#'                 \item{\bold{b. function}:
#'                      You can provide a custom evaluation function. This
#'                      should accept the keyword arguments \code{preds} and \code{dtrain} and should return a named
#'                      list with three elements:
#'                      \itemize{
#'                          \item{\code{name}: A string with the name of the metric, used for printing
#'                              and storing results.
#'                          }
#'                          \item{\code{value}: A single number indicating the value of the metric for the
#'                              given predictions and true values.
#'                          }
#'                          \item{
#'                              \code{higher_better}: A boolean indicating whether higher values indicate a better fit.
#'                              For example, this would be \code{FALSE} for metrics like MAE or RMSE.
#'                          }
#'                      }
#'                 }
#'                 \item{\bold{c. list}:
#'                     If a list is given, it should only contain character vectors and functions.
#'                     These should follow the requirements from the descriptions above.
#'                 }
#'             }
#' @param eval_freq evaluation output frequency; only takes effect when \code{verbose > 0}
#' @param init_model path of a model file or an \code{lgb.Booster} object; training will continue from this model
#' @param nrounds number of training rounds
#' @param obj objective function, can be character or custom objective function. Examples include
#'            \code{regression}, \code{regression_l1}, \code{huber},
#'            \code{binary}, \code{lambdarank}, \code{multiclass}
#' @param params a list of parameters. See \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html}{
#'               the "Parameters" section of the documentation} for a list of parameters and valid values.
#' @param verbose verbosity for output. If <= 0, printing of evaluation results during training is also disabled
#' @param serializable whether to make the resulting objects serializable through functions such as
#'                     \code{save} or \code{saveRDS} (see section "Model serialization").
#' @section Early Stopping:
#'
#'          "Early stopping" refers to stopping the training process if the model's performance on a given
#'          validation set does not improve for several consecutive iterations.
#'
#'          If multiple arguments are given to \code{eval}, their order will be preserved. If you enable
#'          early stopping by setting \code{early_stopping_rounds} in \code{params}, by default all
#'          metrics will be considered for early stopping.
#'
#'          If you want to only consider the first metric for early stopping, pass
#'          \code{first_metric_only = TRUE} in \code{params}. Note that if you also specify \code{metric}
#'          in \code{params}, that metric will be considered the "first" one. If you omit \code{metric},
#'          a default metric will be used based on your choice for the parameter \code{obj} (keyword argument)
#'          or \code{objective} (passed into \code{params}).
#' @section Model serialization:
#'
#'          LightGBM model objects can be serialized and de-serialized through functions such as \code{save}
#'          or \code{saveRDS}. As with libraries such as 'xgboost', however, serialization works a bit
#'          differently from typical R objects. To make a model serializable in R, a copy of the underlying
#'          C++ object is produced as serialized raw bytes and stored in the R model object. When this R object
#'          is de-serialized, the underlying C++ model object is reconstructed from these raw bytes, but only
#'          once some function that uses it is called, such as \code{predict}. To forcibly reconstruct the
#'          C++ object after deserialization (e.g. after calling \code{readRDS} or similar), use the function
#'          \link{lgb.restore_handle}. For example, if predictions are made in parallel or in forked processes,
#'          it is faster to restore the handle beforehand.
#'
#'          Producing and keeping these raw bytes uses extra memory, however. If they are not required,
#'          it is possible to avoid producing them by passing \code{serializable = FALSE}. In such cases, the raw
#'          bytes can be added to the model on demand through the function \link{lgb.make_serializable}.
#' @keywords internal
NULL
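
# The contract for a custom evaluation function described under `eval`, and the
# stopping rule described under "Early Stopping", can be sketched in base R.
# This is a minimal illustration under stated assumptions, not LightGBM's
# implementation: `make_mae_eval`, `get_labels`, `best_round`, and `toy_data`
# are hypothetical names, and a plain list stands in for an lgb.Dataset.

```r
# Hypothetical factory. `get_labels` is an assumed accessor that extracts the
# true labels from whatever object is passed as `dtrain`.
make_mae_eval <- function(get_labels) {
  function(preds, dtrain) {
    labels <- get_labels(dtrain)
    # Return the three named elements required of a custom eval function
    list(
      name = "custom_mae"
      , value = mean(abs(preds - labels))
      , higher_better = FALSE
    )
  }
}

# Illustrative patience rule: walk through the metric history and stop once
# `patience` consecutive rounds have failed to beat the best round so far.
# Returns the index of the best round seen before stopping.
best_round <- function(values, patience, higher_better) {
  # Flip the sign so that "larger is better" holds in both cases
  scores <- if (higher_better) values else -values
  best <- 1L
  for (i in seq_along(scores)) {
    if (scores[i] > scores[best]) {
      best <- i
    } else if (i - best >= patience) {
      break
    }
  }
  best
}

# Stand-in for a dataset: a plain list holding the labels
toy_data <- list(label = c(1.0, 2.0, 3.0))
mae_eval <- make_mae_eval(function(d) d[["label"]])
res <- mae_eval(preds = c(1.5, 2.0, 2.0), dtrain = toy_data)

# Made-up metric values for six boosting rounds
history <- c(0.50, 0.40, 0.45, 0.41, 0.42, 0.30)
stopped_at <- best_round(history, patience = 3L, higher_better = res$higher_better)
```

# With a real lgb.Dataset, `get_labels` would wrap whichever label accessor the
# installed lightgbm version provides. Note that the sketch stops before it ever
# sees the sixth round, which is the trade-off early stopping makes.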

#' @name lightgbm
#' @title Train a LightGBM model
#' @description Simple interface for training a LightGBM model.
#' @inheritParams lgb_shared_params
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weights Sample / observation weights for rows in the input data. If \code{NULL}, will assume that all
#'                observations / rows have the same importance / weight.
#' @param objective Optimization objective (e.g. \code{"regression"}, \code{"binary"}, etc.).
#'                  For a list of accepted objectives, see
#'                  \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#objective}{
#'                  the "objective" item of the "Parameters" section of the documentation}.
#' @param init_score initial score is the base prediction lightgbm will boost from
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#'     \itemize{
#'        \item{\code{valids}: a list of \code{lgb.Dataset} objects, used for validation}
#'        \item{\code{obj}: objective function, can be character or custom objective function. Examples include
#'                   \code{regression}, \code{regression_l1}, \code{huber},
#'                   \code{binary}, \code{lambdarank}, \code{multiclass}}
#'        \item{\code{eval}: evaluation function, can be (a list of) character or custom eval function}
#'        \item{\code{record}: Boolean, TRUE will record iteration message to \code{booster$record_evals}}
#'        \item{\code{colnames}: feature names, if not null, will use this to overwrite the names in dataset}
#'        \item{\code{categorical_feature}: categorical features. This can either be a character vector of feature
#'                            names or an integer vector with the indices of the features (e.g. \code{c(1L, 10L)} to
#'                            say "the first and tenth columns").}
#'        \item{\code{reset_data}: Boolean, setting it to TRUE (not the default value) will transform the booster model
#'                          into a predictor model which frees up memory and the original datasets}
#'     }
#' @inheritSection lgb_shared_params Early Stopping
#' @return a trained \code{lgb.Booster}
#' @export
lightgbm <- function(data,
                     label = NULL,
                     weights = NULL,
                     params = list(),
                     nrounds = 100L,
                     verbose = 1L,
                     eval_freq = 1L,
                     early_stopping_rounds = NULL,
                     init_model = NULL,
                     callbacks = list(),
                     serializable = TRUE,
                     objective = "regression",
                     init_score = NULL,
                     ...) {

  # validate inputs early to avoid unnecessary computation
  if (nrounds <= 0L) {
    stop("nrounds should be greater than zero")
  }

  # Set data to a temporary variable
  dtrain <- data

  # Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
  if (!lgb.is.Dataset(x = dtrain)) {
    dtrain <- lgb.Dataset(data = data, label = label, weight = weights, init_score = init_score)
  }

  train_args <- list(
    "params" = params
    , "data" = dtrain
    , "nrounds" = nrounds
    , "obj" = objective
    , "verbose" = verbose
    , "eval_freq" = eval_freq
    , "early_stopping_rounds" = early_stopping_rounds
    , "init_model" = init_model
    , "callbacks" = callbacks
    , "serializable" = serializable
  )
  train_args <- append(train_args, list(...))

  if (! "valids" %in% names(train_args)) {
    train_args[["valids"]] <- list()
  }

  # When verbose output is requested, also evaluate on the training data
  if (verbose > 0L) {
    train_args[["valids"]][["train"]] <- dtrain
  }

  # Train the model by delegating to lgb.train()
  bst <- do.call(
    what = lgb.train
    , args = train_args
  )

  return(bst)
}
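
# A usage sketch for the wrapper above, written against the signature shown in
# this file (keyword arguments `label`, `params`, `nrounds`, `objective`, and
# `verbose`). It assumes the lightgbm package and its bundled agaricus data are
# installed, and does nothing otherwise. The parameter values are illustrative
# choices, not recommendations.

```r
if (requireNamespace("lightgbm", quietly = TRUE)) {
  library(lightgbm)

  # Bundled example data documented later in this file
  data(agaricus.train, package = "lightgbm")
  data(agaricus.test, package = "lightgbm")

  # `data` is a sparse matrix rather than an lgb.Dataset, so `label` is
  # supplied separately and the wrapper builds the lgb.Dataset itself
  bst <- lightgbm(
    data = agaricus.train$data
    , label = agaricus.train$label
    , params = list(num_leaves = 4L, learning_rate = 1.0)
    , nrounds = 5L
    , objective = "binary"
    , verbose = -1L
  )

  # One prediction per row of the test matrix
  preds <- predict(bst, agaricus.test$data)
}
```

# Because `verbose` is not positive here, the wrapper does not add the training
# data to `valids`, so no evaluation results are printed during training.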

#' @name agaricus.train
#' @title Training part from Mushroom Data Set
#' @description This data set is originally from the Mushroom data set,
#'              UCI Machine Learning Repository.
#'              This data set includes the following fields:
#'
#'              \itemize{
#'                  \item{\code{label}: the label for each record}
#'                  \item{\code{data}: a sparse Matrix of \code{dgCMatrix} class, with 126 columns.}
#'              }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @usage data(agaricus.train)
#' @format A list containing a label vector, and a dgCMatrix object with 6513
#' rows and 126 variables
NULL

#' @name agaricus.test
#' @title Test part from Mushroom Data Set
#' @description This data set is originally from the Mushroom data set,
#'              UCI Machine Learning Repository.
#'              This data set includes the following fields:
#'
#'              \itemize{
#'                  \item{\code{label}: the label for each record}
#'                  \item{\code{data}: a sparse Matrix of \code{dgCMatrix} class, with 126 columns.}
#'              }
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @usage data(agaricus.test)
#' @format A list containing a label vector, and a dgCMatrix object with 1611
#' rows and 126 variables
NULL

#' @name bank
#' @title Bank Marketing Data Set
#' @description This data set is originally from the Bank Marketing data set,
#'              UCI Machine Learning Repository.
#'
#'              It contains only the following: bank.csv with 10% of the examples and 17 inputs,
#'              randomly selected from 3 (older version of this dataset with fewer inputs).
#'
#' @references
#' http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
#'
#' S. Moro, P. Cortez and P. Rita. (2014)
#' A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems
#'
#' @docType data
#' @keywords datasets
#' @usage data(bank)
#' @format A data.table with 4521 rows and 17 variables
NULL

# Various imports
#' @import methods
#' @importFrom Matrix Matrix
#' @importFrom R6 R6Class
#' @useDynLib lib_lightgbm , .registration = TRUE
NULL

# Suppress false positive warnings from R CMD CHECK about
# "unrecognized global variable"
globalVariables(c(
    "."
    , ".N"
    , ".SD"
    , "abs_contribution"
    , "bar_color"
    , "Contribution"
    , "Cover"
    , "Feature"
    , "Frequency"
    , "Gain"
    , "internal_count"
    , "internal_value"
    , "leaf_index"
    , "leaf_parent"
    , "leaf_value"
    , "node_parent"
    , "split_feature"
    , "split_gain"
    , "split_index"
    , "tree_index"
))