#' @name lgb_shared_params
#' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks list of callback functions
#'        List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training
#' @param early_stopping_rounds int. Activates early stopping. Requires at least one validation data
#'                              and one metric. If there's more than one, will check all of them
#'                              except the training data. Returns the model with (best_iter + early_stopping_rounds).
#'                              If early stopping occurs, the model will have 'best_iter' field.
#' @param eval_freq evaluation output frequency, only effective when verbose > 0
#' @param init_model path of model file or \code{lgb.Booster} object, will continue training from this model
#' @param nrounds number of training rounds
#' @param params List of parameters
#' @param verbose verbosity for output, if <= 0, also will disable the printing of evaluation during training
NULL


#' @title Train a LightGBM model
#' @name lightgbm
#' @description Simple interface for training a LightGBM model.
#' @inheritParams lgb_shared_params
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight Vector of sample weights. If not NULL, will be set on the dataset
#' @param save_name File name to use when writing the trained model to disk. Should end in ".model".
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#'     \itemize{
#'        \item{valids}{a list of \code{lgb.Dataset} objects, used for validation}
#'        \item{obj}{objective function, can be character or custom objective function. Examples include
#'                   \code{regression}, \code{regression_l1}, \code{huber},
#'                    \code{binary}, \code{lambdarank}, \code{multiclass}}
#'        \item{eval}{evaluation function, can be (a list of) character or custom eval function}
#'        \item{record}{Boolean, TRUE will record iteration message to \code{booster$record_evals}}
#'        \item{colnames}{feature names, if not null, will use this to overwrite the names in dataset}
#'        \item{categorical_feature}{list of str or int. type int represents index, type str represents feature names}
#'        \item{reset_data}{Boolean, setting it to TRUE (not the default value) will transform the booster model
#'                          into a predictor model which frees up memory and the original datasets}
#'         \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}}
#'         \item{num_leaves}{number of leaves in one tree. defaults to 127}
#'         \item{max_depth}{Limit the max depth for tree model. This is used to deal with
#'                          overfitting when #data is small. Tree still grows leaf-wise.}
#'          \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to
#'                             the number of real CPU cores, not the number of threads (most
#'                             CPUs use hyper-threading to generate 2 threads per CPU core).}
#'     }
#' @return a trained \code{lgb.Booster}
#' @export
lightgbm <- function(data,
                     label = NULL,
                     weight = NULL,
                     params = list(),
                     nrounds = 10L,
                     verbose = 1L,
                     eval_freq = 1L,
                     early_stopping_rounds = NULL,
                     save_name = "lightgbm.model",
                     init_model = NULL,
                     callbacks = list(),
                     ...) {

  # Set data to a temporary variable
  dtrain <- data

  # Training with zero or fewer rounds is meaningless
  if (nrounds <= 0L) {
    stop("nrounds should be greater than zero")
  }

  # Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
  if (!lgb.is.Dataset(dtrain)) {
    dtrain <- lgb.Dataset(data, label = label, weight = weight)
  }

  # Use the training data itself as a validation set, so that evaluation
  # results are printed during training (only when verbose output is on)
  valids <- list()
  if (verbose > 0L) {
    valids$train <- dtrain
  }

  # Train a model using the regular way
  bst <- lgb.train(
    params = params
    , data = dtrain
    , nrounds = nrounds
    , valids = valids
    , verbose = verbose
    , eval_freq = eval_freq
    , early_stopping_rounds = early_stopping_rounds
    , init_model = init_model
    , callbacks = callbacks
    , ...
  )

  # Store model under a specific name
  bst$save_model(save_name)

  # Return booster
  return(bst)
}

#' Training part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.train
#' @usage data(agaricus.train)
# NOTE(review): original @format said "127 variables" while the description
# above says 126 columns; 126 matches the test-set docs — confirm against data.
#' @format A list containing a label vector, and a dgCMatrix object with 6513
#' rows and 126 variables
NULL

#' Test part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.test
#' @usage data(agaricus.test)
#' @format A list containing a label vector, and a dgCMatrix object with 1611
#' rows and 126 variables
NULL

#' Bank Marketing Data Set
#'
#' This data set is originally from the Bank Marketing data set,
#' UCI Machine Learning Repository.
#'
#' It contains only the following: bank.csv with 10% of the examples and 17 inputs,
#' randomly selected from 3 (older version of this dataset with less inputs).
#'
#' @references
#' http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
#'
#' S. Moro, P. Cortez and P. Rita. (2014)
#' A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems
#'
#' @docType data
#' @keywords datasets
#' @name bank
#' @usage data(bank)
#' @format A data.table with 4521 rows and 17 variables
NULL

# Various imports
#' @import methods
#' @importFrom R6 R6Class
#' @useDynLib lib_lightgbm , .registration = TRUE
NULL
# Suppress false positive warnings from R CMD CHECK about
# "unrecognized global variable" (these are data.table column names
# used with non-standard evaluation elsewhere in the package)
globalVariables(c(
    "."
    , ".N"
    , ".SD"
    , "abs_contribution"
    , "Contribution"
    , "Cover"
    , "Feature"
    , "Frequency"
    , "Gain"
    , "internal_count"
    , "internal_value"
    , "leaf_index"
    , "leaf_parent"
    , "leaf_value"
    , "node_parent"
    , "split_feature"
    , "split_gain"
    , "split_index"
    , "tree_index"
))