# lightgbm.R

#' @name lgb_shared_params
#' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks list of callback functions
#'        List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training
#' @param early_stopping_rounds int
#'        Activates early stopping.
#'        Requires at least one validation data and one metric.
#'        If there's more than one, will check all of them except the training data.
#'        Returns the model with (best_iter + early_stopping_rounds).
#'        If early stopping occurs, the model will have a 'best_iter' field.
#' @param eval_freq evaluation output frequency, only has an effect when verbose > 0
#' @param init_model path of model file or \code{lgb.Booster} object, will continue training from this model
#' @param nrounds number of training rounds
#' @param params List of parameters
#' @param verbose verbosity for output, if <= 0, also will disable the printing of evaluation during training
NULL


#' @title Train a LightGBM model
#' @name lightgbm
#' @description Simple interface for training a LightGBM model.
#' @inheritParams lgb_shared_params
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight Vector of weights. If not NULL, will be set on the dataset
#' @param save_name File name to use when writing the trained model to disk. Should end in ".model".
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#'     \itemize{
#'        \item{valids}{a list of \code{lgb.Dataset} objects, used for validation}
#'        \item{obj}{objective function, can be character or custom objective function. Examples include
#'                   \code{regression}, \code{regression_l1}, \code{huber},
#'                   \code{binary}, \code{lambdarank}, \code{multiclass}}
#'        \item{eval}{evaluation function, can be (a list of) character or custom eval function}
#'        \item{record}{Boolean, TRUE will record iteration message to \code{booster$record_evals}}
#'        \item{colnames}{feature names, if not null, will use this to overwrite the names in dataset}
#'        \item{categorical_feature}{list of str or int. type int represents index, type str represents feature names}
#'        \item{reset_data}{Boolean, setting it to TRUE (not the default value) will transform the booster model
#'                          into a predictor model which frees up memory and the original datasets}
#'        \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}}
#'        \item{num_leaves}{number of leaves in one tree. defaults to 127}
#'        \item{max_depth}{Limit the max depth for tree model. This is used to deal with
#'                         overfitting when #data is small. Tree still grows leaf-wise.}
#'        \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to
#'                           the number of real CPU cores, not the number of threads (most
#'                           CPUs use hyper-threading to generate 2 threads per CPU core).}
#'     }
#' @return a trained \code{lgb.Booster}
#' @export
lightgbm <- function(data,
                     label = NULL,
                     weight = NULL,
                     params = list(),
                     nrounds = 10,
                     verbose = 1,
                     eval_freq = 1L,
                     early_stopping_rounds = NULL,
                     save_name = "lightgbm.model",
                     init_model = NULL,
                     callbacks = list(),
                     ...) {

  # Keep a reference to the training data under a separate name
  dtrain <- data

  # Validate the number of boosting rounds before doing any work
  if (nrounds <= 0) {
    stop("nrounds should be greater than zero")
  }

  # If data is not already an lgb.Dataset (e.g. a raw matrix), wrap it in one,
  # attaching the label and weight vectors when provided
  if (!lgb.is.Dataset(dtrain)) {
    dtrain <- lgb.Dataset(data, label = label, weight = weight)
  }

  # When verbose output is requested, evaluate on the training data itself
  valids <- list()
  if (verbose > 0) {
    valids$train <- dtrain
  }

  # Delegate the actual training to lgb.train(), forwarding extra args via ...
  bst <- lgb.train(params, dtrain, nrounds, valids, verbose = verbose, eval_freq = eval_freq,
                   early_stopping_rounds = early_stopping_rounds,
                   init_model = init_model, callbacks = callbacks, ...)

  # Persist the trained model to disk under the requested file name
  bst$save_model(save_name)

  # Return booster
  return(bst)
}

#' Training part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.train
#' @usage data(agaricus.train)
#' @format A list containing a label vector, and a dgCMatrix object with 6513
#' rows and 127 variables
NULL

#' Test part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.test
#' @usage data(agaricus.test)
#' @format A list containing a label vector, and a dgCMatrix object with 1611
#' rows and 126 variables
NULL

#' Bank Marketing Data Set
#'
#' This data set is originally from the Bank Marketing data set,
#' UCI Machine Learning Repository.
#'
#' It contains only the following: bank.csv with 10% of the examples and 17 inputs,
#' randomly selected from 3 (older version of this dataset with fewer inputs).
#'
#' @references
#' http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
#'
#' S. Moro, P. Cortez and P. Rita. (2014)
#' A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems
#'
#' @docType data
#' @keywords datasets
#' @name bank
#' @usage data(bank)
#' @format A data.table with 4521 rows and 17 variables
NULL

Guolin Ke's avatar
Guolin Ke committed
166
# Various imports
# roxygen directives below generate the package NAMESPACE entries:
# methods and R6 are required at runtime; the shared library is loaded
# with symbol registration.
#' @import methods
#' @importFrom R6 R6Class
#' @useDynLib lib_lightgbm , .registration = TRUE
NULL

# Suppress false positive warnings from R CMD CHECK about
# "unrecognized global variable" (names used non-standardly
# inside data.table expressions elsewhere in the package)
globalVariables(c(
    ".", ".N", ".SD",
    "Contribution", "Cover", "Feature", "Frequency", "Gain",
    "internal_count", "internal_value",
    "leaf_index", "leaf_parent", "leaf_value",
    "node_parent",
    "split_feature", "split_gain", "split_index",
    "tree_index"
))