lightgbm.R 7.1 KB
Newer Older
James Lamb's avatar
James Lamb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#' @name lgb_shared_params
#' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks list of callback functions
#'        List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training
#' @param early_stopping_rounds int
#'        Activates early stopping.
#'        Requires at least one validation dataset and one metric.
#'        If there's more than one, will check all of them except the training data.
#'        Returns the model with (best_iter + early_stopping_rounds).
#'        If early stopping occurs, the model will have a 'best_iter' field.
#' @param eval_freq evaluation output frequency, only takes effect when verbose > 0
#' @param init_model path of model file or \code{lgb.Booster} object, will continue training from this model
#' @param nrounds number of training rounds
#' @param params List of parameters
#' @param verbose verbosity for output, if <= 0, will also disable the printing of evaluation results during training
NULL


#' @title Train a LightGBM model
#' @name lightgbm
23
#' @description Simple interface for training a LightGBM model.
James Lamb's avatar
James Lamb committed
24
25
26
27
28
29
30
#' @inheritParams lgb_shared_params
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight Vector of weights, used if \code{data} is not an \code{\link{lgb.Dataset}}. If not NULL, will be set on the dataset
#' @param save_name File name to use when writing the trained model to disk. Should end in ".model".
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#'     \itemize{
#'        \item{valids}{a list of \code{lgb.Dataset} objects, used for validation}
31
#'        \item{obj}{objective function, can be character or custom objective function. Examples include
James Lamb's avatar
James Lamb committed
32
33
34
35
36
37
#'                   \code{regression}, \code{regression_l1}, \code{huber},
#'                    \code{binary}, \code{lambdarank}, \code{multiclass}}
#'        \item{eval}{evaluation function, can be (a list of) character or custom eval function}
#'        \item{record}{Boolean, TRUE will record iteration message to \code{booster$record_evals}}
#'        \item{colnames}{feature names, if not null, will use this to overwrite the names in dataset}
#'        \item{categorical_feature}{list of str or int. type int represents index, type str represents feature names}
38
#'        \item{reset_data}{Boolean, setting it to TRUE (not the default value) will transform the booster model
James Lamb's avatar
James Lamb committed
39
40
41
#'                          into a predictor model which frees up memory and the original datasets}
#'         \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}}
#'         \item{num_leaves}{number of leaves in one tree. defaults to 127}
42
#'         \item{max_depth}{Limit the max depth for tree model. This is used to deal with
James Lamb's avatar
James Lamb committed
43
44
#'                          overfitting when the amount of data is small. Trees still grow leaf-wise.}
#'          \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to
45
#'                             the number of real CPU cores, not the number of threads (most
James Lamb's avatar
James Lamb committed
46
47
#'                             CPUs use hyper-threading to generate 2 threads per CPU core).}
#'     }
Guolin Ke's avatar
Guolin Ke committed
48
#' @export
49
50
51
52
53
54
55
56
57
58
59
60
lightgbm <- function(data,
                     label = NULL,
                     weight = NULL,
                     params = list(),
                     nrounds = 10,
                     verbose = 1,
                     eval_freq = 1L,
                     early_stopping_rounds = NULL,
                     save_name = "lightgbm.model",
                     init_model = NULL,
                     callbacks = list(),
                     ...) {

  # Simple training interface: wraps `data` in an lgb.Dataset when needed,
  # trains with lgb.train(), writes the model to `save_name` and returns
  # the object produced by lgb.train().

  # Set data to a temporary variable
  dtrain <- data

  if (nrounds <= 0) {
    stop("nrounds should be greater than zero")
  }

  # Build an lgb.Dataset from raw input; `label` and `weight` are only used here
  if (!lgb.is.Dataset(dtrain)) {
    dtrain <- lgb.Dataset(data, label = label, weight = weight)
  }

  # When output is enabled, the training set itself is evaluated as "train"
  train_valids <- list()
  if (verbose > 0) {
    train_valids$train <- dtrain
  }

  # `valids` is documented as something callers may pass through `...`.
  # Forwarding it unchanged would collide with the validation list built
  # above (two candidates for the same lgb.train() argument), so a small
  # closure peels it out of `...` and merges both lists. Because `valids`
  # comes after `...` in the closure's formals it is matched by exact name
  # only; every other extra argument is forwarded untouched and lazily.
  run_training <- function(..., valids = list()) {
    lgb.train(params, dtrain, nrounds, c(train_valids, valids),
              verbose = verbose, eval_freq = eval_freq,
              early_stopping_rounds = early_stopping_rounds,
              init_model = init_model, callbacks = callbacks, ...)
  }
  bst <- run_training(...)

  # Store model under a specific name
  bst$save_model(save_name)

  # Return booster
  return(bst)
}

#' Training part from Mushroom Data Set
91
#'
Guolin Ke's avatar
Guolin Ke committed
92
93
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
94
#'
Guolin Ke's avatar
Guolin Ke committed
95
#' This data set includes the following fields:
96
#'
Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
101
102
103
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
104
105
106
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
Guolin Ke's avatar
Guolin Ke committed
107
#' School of Information and Computer Science.
108
#'
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
#' @docType data
#' @keywords datasets
#' @name agaricus.train
#' @usage data(agaricus.train)
113
#' @format A list containing a label vector, and a dgCMatrix object with 6513
Guolin Ke's avatar
Guolin Ke committed
114
115
116
117
118
119
120
#' rows and 126 variables
NULL

#' Test part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
121
#'
Guolin Ke's avatar
Guolin Ke committed
122
#' This data set includes the following fields:
123
#'
Guolin Ke's avatar
Guolin Ke committed
124
125
126
127
128
129
130
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
131
132
133
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
Guolin Ke's avatar
Guolin Ke committed
134
#' School of Information and Computer Science.
135
#'
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
#' @docType data
#' @keywords datasets
#' @name agaricus.test
#' @usage data(agaricus.test)
140
#' @format A list containing a label vector, and a dgCMatrix object with 1611
Guolin Ke's avatar
Guolin Ke committed
141
142
143
#' rows and 126 variables
NULL

144
145
146
147
148
149
150
151
152
153
#' Bank Marketing Data Set
#'
#' This data set is originally from the Bank Marketing data set,
#' UCI Machine Learning Repository.
#'
#' It contains only the following: bank.csv with 10% of the examples and 17 inputs,
#' randomly selected from bank-full.csv (an older version of this dataset with fewer inputs).
#'
#' @references
#' http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
154
#'
155
156
157
158
159
160
161
162
163
164
#' S. Moro, P. Cortez and P. Rita. (2014)
#' A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems
#'
#' @docType data
#' @keywords datasets
#' @name bank
#' @usage data(bank)
#' @format A data.table with 4521 rows and 17 variables
NULL

Guolin Ke's avatar
Guolin Ke committed
165
# Various imports
Guolin Ke's avatar
Guolin Ke committed
166
#' @import methods
Guolin Ke's avatar
Guolin Ke committed
167
#' @importFrom R6 R6Class
James Lamb's avatar
James Lamb committed
168
#' @useDynLib lib_lightgbm , .registration = TRUE
169
NULL
James Lamb's avatar
James Lamb committed
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192

# R CMD check reports "no visible binding for global variable" for bare
# symbols it cannot resolve statically. Registering the names below
# silences those false positives.
globalVariables(
  c(
    # dot-prefixed symbols (presumably data.table's special symbols -- verify)
    ".", ".N", ".SD",
    # capitalised names
    "Contribution", "Cover", "Feature", "Frequency", "Gain",
    # snake_case names
    "internal_count", "internal_value",
    "leaf_index", "leaf_parent", "leaf_value",
    "node_parent",
    "split_feature", "split_gain", "split_index",
    "tree_index"
  )
)