#' @name lgb_shared_params
#' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training
#' @param early_stopping_rounds int. Activates early stopping. Requires at least one validation data
#'                              and one metric. If there's more than one, will check all of them
#'                              except the training data. Returns the model with (best_iter + early_stopping_rounds).
#'                              If early stopping occurs, the model will have 'best_iter' field.
#' @param eval_freq evaluation output frequency, only effective when verbose > 0
#' @param init_model path of model file or \code{lgb.Booster} object, will continue training from this model
#' @param nrounds number of training rounds
#' @param params List of parameters
#' @param verbose verbosity for output, if <= 0, also will disable the print of evaluation during training
NULL


#' @title Train a LightGBM model
#' @name lightgbm
#' @description Simple interface for training a LightGBM model.
#' @inheritParams lgb_shared_params
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of sample weights. If not NULL, will set to dataset
#' @param save_name File name to use when writing the trained model to disk. Should end in ".model".
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#'     \itemize{
#'        \item{valids}{a list of \code{lgb.Dataset} objects, used for validation}
#'        \item{obj}{objective function, can be character or custom objective function. Examples include
#'                   \code{regression}, \code{regression_l1}, \code{huber},
#'                   \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclassova}}
#'        \item{eval}{evaluation function, can be (a list of) character or custom eval function}
#'        \item{record}{Boolean, TRUE will record iteration message to \code{booster$record_evals}}
#'        \item{colnames}{feature names, if not null, will use this to overwrite the names in dataset}
#'        \item{categorical_feature}{list of str or int. type int represents index, type str represents feature names}
#'        \item{reset_data}{Boolean, setting it to TRUE (not the default value) will transform the booster model
#'                          into a predictor model which frees up memory and the original datasets}
#'        \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}}
#'        \item{num_leaves}{number of leaves in one tree. defaults to 127}
#'        \item{max_depth}{Limit the max depth for tree model. This is used to deal with
#'                         overfit when #data is small. Tree still grow by leaf-wise.}
#'        \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to
#'                           the number of real CPU cores, not the number of threads (most
#'                           CPU using hyper-threading to generate 2 threads per CPU core).}
#'     }
#' @export
lightgbm <- function(data,
                     label = NULL,
                     weight = NULL,
                     params = list(),
                     nrounds = 10,
                     verbose = 1,
                     eval_freq = 1L,
                     early_stopping_rounds = NULL,
                     save_name = "lightgbm.model",
                     init_model = NULL,
                     callbacks = list(),
                     ...) {

  # Work on a separate variable so the caller's `data` argument is left untouched
  dtrain <- data

  if (nrounds <= 0) {
    stop("nrounds should be greater than zero")
  }

  # If raw data (e.g. a matrix) was passed in, wrap it in an lgb.Dataset
  if (!lgb.is.Dataset(dtrain)) {
    dtrain <- lgb.Dataset(data, label = label, weight = weight)
  }

  # Evaluate on the training data itself when verbose output is requested
  valids <- list()
  if (verbose > 0) {
    valids$train <- dtrain
  }

  # Delegate the actual training loop to lgb.train()
  bst <- lgb.train(
    params = params
    , data = dtrain
    , nrounds = nrounds
    , valids = valids
    , verbose = verbose
    , eval_freq = eval_freq
    , early_stopping_rounds = early_stopping_rounds
    , init_model = init_model
    , callbacks = callbacks
    , ...
  )

  # Persist the fitted model to disk under the requested name
  bst$save_model(save_name)

  # Return booster
  return(bst)
}

#' Training part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.train
#' @usage data(agaricus.train)
#' @format A list containing a label vector, and a dgCMatrix object with 6513
#' rows and 127 variables
NULL

#' Test part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.test
#' @usage data(agaricus.test)
#' @format A list containing a label vector, and a dgCMatrix object with 1611
#' rows and 126 variables
NULL

#' Bank Marketing Data Set
#'
#' This data set is originally from the Bank Marketing data set,
#' UCI Machine Learning Repository.
#'
#' It contains only the following: bank.csv with 10% of the examples and 17 inputs,
#' randomly selected from 3 (older version of this dataset with less inputs).
#'
#' @references
#' http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
#'
#' S. Moro, P. Cortez and P. Rita. (2014)
#' A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems
#'
#' @docType data
#' @keywords datasets
#' @name bank
#' @usage data(bank)
#' @format A data.table with 4521 rows and 17 variables
NULL

# Various imports
#' @import methods
#' @importFrom R6 R6Class
#' @useDynLib lib_lightgbm , .registration = TRUE
NULL
# Suppress false positive warnings from R CMD CHECK about
# "unrecognized global variable"
globalVariables(c(
    "."
    , ".N"
    , ".SD"
    , "Contribution"
    , "Cover"
    , "Feature"
    , "Frequency"
    , "Gain"
    , "internal_count"
    , "internal_value"
    , "leaf_index"
    , "leaf_parent"
    , "leaf_value"
    , "node_parent"
    , "split_feature"
    , "split_gain"
    , "split_index"
    , "tree_index"
))