#' @name lgb_shared_params
#' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
#'             may allow you to pass other types of data like \code{matrix} and then separately supply
#'             \code{label} as a keyword argument.
#' @param early_stopping_rounds int. Activates early stopping. Requires at least one validation data
#'                              and one metric. If there's more than one, will check all of them
#'                              except the training data. Returns the model with (best_iter + early_stopping_rounds).
#'                              If early stopping occurs, the model will have 'best_iter' field.
#' @param eval evaluation function(s). This can be a character vector, function, or list with a mixture of
#'             strings and functions.
#'
#'             \itemize{
#'                 \item{\bold{a. character vector}:
#'                     If you provide a character vector to this argument, it should contain strings with valid
#'                     evaluation metrics.
#'                     See \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric}{
#'                     The "metric" section of the documentation}
#'                     for a list of valid metrics.
#'                 }
#'                 \item{\bold{b. function}:
#'                      You can provide a custom evaluation function. This
#'                      should accept the keyword arguments \code{preds} and \code{dtrain} and should return a named
#'                      list with three elements:
#'                      \itemize{
#'                          \item{\code{name}: A string with the name of the metric, used for printing
#'                              and storing results.
#'                          }
#'                          \item{\code{value}: A single number indicating the value of the metric for the
#'                              given predictions and true values
#'                          }
#'                          \item{
#'                              \code{higher_better}: A boolean indicating whether higher values indicate a better fit.
#'                              For example, this would be \code{FALSE} for metrics like MAE or RMSE.
#'                          }
#'                      }
#'                 }
#'                 \item{\bold{c. list}:
#'                     If a list is given, it should only contain character vectors and functions.
#'                     These should follow the requirements from the descriptions above.
#'                 }
#'             }
#' @param eval_freq evaluation output frequency, only has an effect when verbose > 0
#' @param init_model path of model file or \code{lgb.Booster} object, will continue training from this model
#' @param nrounds number of training rounds
#' @param obj objective function, can be character or custom objective function. Examples include
#'            \code{regression}, \code{regression_l1}, \code{huber},
#'            \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclassova}
#' @param params List of parameters
#' @param verbose verbosity for output, if <= 0, also will disable the print of evaluation during training
#' @section Early Stopping:
#'
#'          "early stopping" refers to stopping the training process if the model's performance on a given
#'          validation set does not improve for several consecutive iterations.
#'
#'          If multiple arguments are given to \code{eval}, their order will be preserved. If you enable
#'          early stopping by setting \code{early_stopping_rounds} in \code{params}, by default all
#'          metrics will be considered for early stopping.
#'
#'          If you want to only consider the first metric for early stopping, pass
#'          \code{first_metric_only = TRUE} in \code{params}. Note that if you also specify \code{metric}
#'          in \code{params}, that metric will be considered the "first" one. If you omit \code{metric},
#'          a default metric will be used based on your choice for the parameter \code{obj} (keyword argument)
#'          or \code{objective} (passed into \code{params}).
#' @keywords internal
NULL

#' @name lightgbm
#' @title Train a LightGBM model
#' @description Simple interface for training a LightGBM model.
#' @inheritParams lgb_shared_params
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight Vector of sample weights. If not NULL, will be set on the dataset
#'               created from \code{data}.
#' @param save_name File name to use when writing the trained model to disk. Should end in ".model".
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#'     \itemize{
#'        \item{\code{valids}: a list of \code{lgb.Dataset} objects, used for validation}
#'        \item{\code{obj}: objective function, can be character or custom objective function. Examples include
#'                   \code{regression}, \code{regression_l1}, \code{huber},
#'                    \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclassova}}
#'        \item{\code{eval}: evaluation function, can be (a list of) character or custom eval function}
#'        \item{\code{record}: Boolean, TRUE will record iteration message to \code{booster$record_evals}}
#'        \item{\code{colnames}: feature names, if not null, will use this to overwrite the names in dataset}
#'        \item{\code{categorical_feature}: categorical features. This can either be a character vector of feature
#'                            names or an integer vector with the indices of the features (e.g. \code{c(1L, 10L)} to
#'                            say "the first and tenth columns").}
#'        \item{\code{reset_data}: Boolean, setting it to TRUE (not the default value) will transform the booster model
#'                          into a predictor model which frees up memory and the original datasets}
#'         \item{\code{boosting}: Boosting type. \code{"gbdt"}, \code{"rf"}, \code{"dart"} or \code{"goss"}.}
#'         \item{\code{num_leaves}: Maximum number of leaves in one tree.}
#'         \item{\code{max_depth}: Limit the max depth for tree model. This is used to deal with
#'                          overfit when #data is small. Tree still grow by leaf-wise.}
#'          \item{\code{num_threads}: Number of threads for LightGBM. For the best speed, set this to
#'                             the number of real CPU cores, not the number of threads (most
#'                             CPU using hyper-threading to generate 2 threads per CPU core).}
#'     }
#' @inheritSection lgb_shared_params Early Stopping
#' @return a trained \code{lgb.Booster}
#' @export
lightgbm <- function(data,
                     label = NULL,
                     weight = NULL,
                     params = list(),
                     nrounds = 10L,
                     verbose = 1L,
                     eval_freq = 1L,
                     early_stopping_rounds = NULL,
                     save_name = "lightgbm.model",
                     init_model = NULL,
                     callbacks = list(),
                     ...) {

  # validate inputs early to avoid unnecessary computation
  if (nrounds <= 0L) {
    stop("nrounds should be greater than zero")
  }

  # Set data to a temporary variable
  dtrain <- data

  # Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
  if (!lgb.is.Dataset(x = dtrain)) {
    dtrain <- lgb.Dataset(data = data, label = label, weight = weight)
  }

  # Collect the keyword arguments for lgb.train(), then merge in anything
  # extra supplied through "..." (e.g. valids, obj, eval, colnames)
  train_args <- list(
    "params" = params
    , "data" = dtrain
    , "nrounds" = nrounds
    , "verbose" = verbose
    , "eval_freq" = eval_freq
    , "early_stopping_rounds" = early_stopping_rounds
    , "init_model" = init_model
    , "callbacks" = callbacks
  )
  train_args <- append(train_args, list(...))

  if (! "valids" %in% names(train_args)) {
    train_args[["valids"]] <- list()
  }

  # Use the training data itself as a validation set, so evaluation
  # output is printed during training when verbose > 0
  if (verbose > 0L) {
    train_args[["valids"]][["train"]] <- dtrain
  }

  # Train a model using the regular way
  bst <- do.call(
    what = lgb.train
    , args = train_args
  )

  # Store model under a specific name
  bst$save_model(filename = save_name)

  return(bst)
}

#' @name agaricus.train
#' @title Training part from Mushroom Data Set
#' @description This data set is originally from the Mushroom data set,
#'              UCI Machine Learning Repository.
#'              This data set includes the following fields:
#'
#'              \itemize{
#'                  \item{\code{label}: the label for each record}
#'                  \item{\code{data}: a sparse Matrix of \code{dgCMatrix} class, with 126 columns.}
#'              }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @usage data(agaricus.train)
#' @format A list containing a label vector, and a dgCMatrix object with 6513
#' rows and 127 variables
NULL

#' @name agaricus.test
#' @title Test part from Mushroom Data Set
#' @description This data set is originally from the Mushroom data set,
#'              UCI Machine Learning Repository.
#'              This data set includes the following fields:
#'
#'              \itemize{
#'                  \item{\code{label}: the label for each record}
#'                  \item{\code{data}: a sparse Matrix of \code{dgCMatrix} class, with 126 columns.}
#'              }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @usage data(agaricus.test)
#' @format A list containing a label vector, and a dgCMatrix object with 1611
#' rows and 126 variables
NULL

#' @name bank
#' @title Bank Marketing Data Set
#' @description This data set is originally from the Bank Marketing data set,
#'              UCI Machine Learning Repository.
#'
#'              It contains only the following: bank.csv with 10% of the examples and 17 inputs,
#'              randomly selected from 3 (older version of this dataset with less inputs).
#'
#' @references
#' http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
#'
#' S. Moro, P. Cortez and P. Rita. (2014)
#' A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems
#'
#' @docType data
#' @keywords datasets
#' @usage data(bank)
#' @format A data.table with 4521 rows and 17 variables
NULL

# Various imports
#' @import methods
#' @importFrom Matrix Matrix
#' @importFrom R6 R6Class
#' @useDynLib lib_lightgbm , .registration = TRUE
NULL
# Suppress false positive warnings from R CMD CHECK about
# "unrecognized global variable". These names are data.table column
# references used with non-standard evaluation elsewhere in the package.
globalVariables(c(
    "."
    , ".N"
    , ".SD"
    , "abs_contribution"
    , "bar_color"
    , "Contribution"
    , "Cover"
    , "Feature"
    , "Frequency"
    , "Gain"
    , "internal_count"
    , "internal_value"
    , "leaf_index"
    , "leaf_parent"
    , "leaf_value"
    , "node_parent"
    , "split_feature"
    , "split_gain"
    , "split_index"
    , "tree_index"
))