lgb.train.R 6.82 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
#' Main training logic for LightGBM
2
#'
Guolin Ke's avatar
Guolin Ke committed
3
4
5
#' @param params List of parameters
#' @param data a \code{lgb.Dataset} object, used for training
#' @param nrounds number of training rounds
6
#' @param valids a list of \code{lgb.Dataset} objects, used for validation
7
8
9
10
11
12
13
14
#' @param obj objective function, can be character or custom objective function. Examples include
#'        \code{regression}, \code{regression_l1}, \code{huber},
#'        \code{binary}, \code{lambdarank}, \code{multiclass}
#' @param boosting boosting type. \code{gbdt}, \code{dart}
#' @param num_leaves number of leaves in one tree. defaults to 127
#' @param max_depth Limit the max depth for tree model. This is used to deal with overfitting when #data is small.
#'        Trees are still grown leaf-wise.
#' @param num_threads Number of threads for LightGBM. For the best speed, set this to the number of real CPU cores, not the number of threads (most CPUs use hyper-threading to generate 2 threads per CPU core).
15
#' @param eval evaluation function, can be (a list of) character or custom eval function
16
17
18
#' @param verbose verbosity for output; if <= 0, printing of evaluation results during training is also disabled
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param eval_freq evaluation output frequency, only takes effect when verbose > 0
19
#' @param init_model path of model file or \code{lgb.Booster} object; training will continue from this model
Guolin Ke's avatar
Guolin Ke committed
20
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset
21
22
23
#' @param categorical_feature list of str or int
#'        type int represents index,
#'        type str represents feature names
Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
28
29
30
31
32
#' @param early_stopping_rounds int
#'        Activates early stopping.
#'        Requires at least one validation data and one metric
#'        If there's more than one, will check all of them
#'        Returns the model with (best_iter + early_stopping_rounds)
#'        If early stopping occurs, the model will have 'best_iter' field
#' @param callbacks list of callback functions
#'        List of callback functions that are applied at each iteration.
#' @param ... other parameters, see parameters.md for more information
33
#' @return a trained booster model \code{lgb.Booster}.
Guolin Ke's avatar
Guolin Ke committed
34
#' @examples
35
36
37
38
39
40
41
42
43
44
45
46
#' \dontrun{
#'   library(lightgbm)
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   data(agaricus.test, package='lightgbm')
#'   test <- agaricus.test
#'   dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#'   params <- list(objective="regression", metric="l2")
#'   valids <- list(test=dtest)
#'   model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' }
Guolin Ke's avatar
Guolin Ke committed
47
48
#' @rdname lgb.train
#' @export
49
50
51
52
53
lgb.train <- function(params = list(), data, nrounds = 10,
                      valids                = list(),
                      obj                   = NULL,
                      eval                  = NULL,
                      verbose               = 1,
                      record                = TRUE,
                      eval_freq             = 1L,
                      init_model            = NULL,
                      colnames              = NULL,
                      categorical_feature   = NULL,
                      early_stopping_rounds = NULL,
                      callbacks             = list(), ...) {
  # Fold the extra arguments given through `...` into the parameter list
  additional_params <- list(...)
  params         <- append(params, additional_params)
  params$verbose <- verbose
  params         <- lgb.check.obj(params, obj)
  params         <- lgb.check.eval(params, eval)
  fobj           <- NULL
  feval          <- NULL

  # A function-valued objective is kept aside and handed to booster$update();
  # the parameter list then carries the placeholder "NONE" instead
  if (is.function(params$objective)) {
    fobj             <- params$objective
    params$objective <- "NONE"
  }
  # A function-valued eval is forwarded to booster$eval_train()/eval_valid()
  if (is.function(eval)) {
    feval <- eval
  }
  lgb.check.params(params)

  # init_model may be a character value (passed to Predictor$new) or an
  # lgb.Booster (converted through its to_predictor() method)
  predictor <- NULL
  if (is.character(init_model)) {
    predictor <- Predictor$new(init_model)
  } else if (lgb.is.Booster(init_model)) {
    predictor <- init_model$to_predictor()
  }

  # When continuing from an existing model, iteration numbering resumes
  # after the predictor's current iteration
  begin_iteration <- 1
  if (!is.null(predictor)) {
    begin_iteration <- predictor$current_iter() + 1
  }
  end_iteration <- begin_iteration + nrounds - 1

  # check dataset
  if (!lgb.is.Dataset(data)) {
    stop("lgb.train: data only accepts lgb.Dataset object")
  }
  if (length(valids) > 0) {
    # vapply() keeps the result type fixed to logical regardless of input
    if (!is.list(valids) ||
        !all(vapply(valids, lgb.is.Dataset, logical(1)))) {
      stop("lgb.train: valids must be a list of lgb.Dataset elements")
    }
    evnames <- names(valids)
    if (is.null(evnames) || !all(nzchar(evnames))) {
      stop("lgb.train: each element of the valids must have a name tag")
    }
  }

  data$update_params(params)
  data$.__enclos_env__$private$set_predictor(predictor)
  if (!is.null(colnames)) {
    data$set_colnames(colnames)
  }
  if (!is.null(categorical_feature)) {
    data$set_categorical_feature(categorical_feature)
  }
  data$construct()

  # Split `valids` into the training set itself (if it was listed) and the
  # remaining validation sets, which are tied to the training data
  valid_contain_train <- FALSE
  train_data_name     <- "train"
  reduced_valid_sets  <- list()
  if (length(valids) > 0) {
    for (key in names(valids)) {
      valid_data <- valids[[key]]
      if (identical(data, valid_data)) {
        valid_contain_train <- TRUE
        train_data_name     <- key
        next
      }
      valid_data$update_params(params)
      valid_data$set_reference(data)
      reduced_valid_sets[[key]] <- valid_data
    }
  }

  # process callbacks
  # (scalar conditions, so the short-circuit && is used rather than &)
  if (verbose > 0 && eval_freq > 0) {
    callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
  }

  if (record && length(valids) > 0) {
    callbacks <- add.cb(callbacks, cb.record.evaluation())
  }

  # Early stopping callback
  if (!is.null(early_stopping_rounds)) {
    if (early_stopping_rounds > 0) {
      callbacks <- add.cb(
        callbacks,
        cb.early.stop(early_stopping_rounds, verbose = verbose)
      )
    }
  }

  cb <- categorize.callbacks(callbacks)

  # construct booster
  booster <- Booster$new(params = params, train_set = data)
  if (valid_contain_train) {
    booster$set_train_data_name(train_data_name)
  }
  for (key in names(reduced_valid_sets)) {
    booster$add_valid(reduced_valid_sets[[key]], key)
  }

  # callback env
  env                 <- CB_ENV$new()
  env$model           <- booster
  env$begin_iteration <- begin_iteration
  env$end_iteration   <- end_iteration

  # seq(from, to) counts DOWN when to < from, so a non-positive nrounds
  # would otherwise still run boosting iterations; use an empty sequence
  # in that case
  iterations <- if (nrounds > 0) {
    seq(from = begin_iteration, to = end_iteration)
  } else {
    integer(0)
  }

  # start training
  for (i in iterations) {
    env$iteration <- i
    env$eval_list <- list()
    for (f in cb$pre_iter) {
      f(env)
    }

    # update one iter
    booster$update(fobj = fobj)

    # collect eval result
    eval_list <- list()
    if (length(valids) > 0) {
      if (valid_contain_train) {
        eval_list <- append(eval_list, booster$eval_train(feval = feval))
      }
      eval_list <- append(eval_list, booster$eval_valid(feval = feval))
    }
    env$eval_list <- eval_list
    for (f in cb$post_iter) {
      f(env)
    }

    # met early stopping
    if (env$met_early_stop) {
      break
    }
  }

  booster
}