lgb.train.R 6.55 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
#' Main training logic for LightGBM
#'
#' @param params List of parameters
#' @param data a \code{lgb.Dataset} object, used for training
#' @param nrounds number of training rounds
#' @param valids a list of \code{lgb.Dataset} objects, used for validation.
#'        Every element must be named; the names are used as evaluation tags.
#' @param obj objective function, can be character or custom objective function. Examples include
#'        \code{regression}, \code{regression_l1}, \code{huber},
#'        \code{binary}, \code{lambdarank}, \code{multiclass}
#' @param boosting boosting type. \code{gbdt}, \code{dart}
#' @param num_leaves number of leaves in one tree. defaults to 127
#' @param max_depth Limit the max depth for tree model. This is used to deal with
#'        over-fitting when the amount of data is small. Trees still grow leaf-wise.
#' @param num_threads Number of threads for LightGBM. For the best speed, set this to
#'        the number of real CPU cores, not the number of threads (most CPUs use
#'        hyper-threading to generate 2 threads per CPU core).
#' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param verbose verbosity for output; if <= 0, printing of evaluation results
#'        during training is also disabled
#' @param record logical; if TRUE, iteration messages are recorded to \code{booster$record_evals}
#' @param eval_freq evaluation output frequency; only takes effect when verbose > 0
#' @param init_model path of a model file or an \code{lgb.Booster} object;
#'        training continues from this model
#' @param colnames feature names; if not NULL, these overwrite the names in the dataset
#' @param early_stopping_rounds int
#'        Activates early stopping.
#'        Requires at least one validation dataset and one metric.
#'        If there's more than one, all of them are checked.
#'        Returns the model with (best_iter + early_stopping_rounds).
#'        If early stopping occurs, the model will have a 'best_iter' field.
#' @param callbacks list of callback functions that are applied at each iteration.
#' @param ... other parameters, see parameters.md for more information
#' @return a trained booster model \code{lgb.Booster}.
#' @examples
#' \dontrun{
#'   library(lightgbm)
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   data(agaricus.test, package='lightgbm')
#'   test <- agaricus.test
#'   dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#'   params <- list(objective="regression", metric="l2")
#'   valids <- list(test=dtest)
#'   model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' }
#' @rdname lgb.train
#' @export
lgb.train <- function(params = list(), data, nrounds = 10,
                      valids                = list(),
                      obj                   = NULL,
                      eval                  = NULL,
                      verbose               = 1,
                      record                = TRUE,
                      eval_freq             = 1L,
                      init_model            = NULL,
                      colnames              = NULL,
                      early_stopping_rounds = NULL,
                      callbacks             = list(), ...) {
  # Merge extra parameters passed through `...` into `params`
  additional_params <- list(...)
  params            <- append(params, additional_params)
  params$verbose    <- verbose
  params            <- lgb.check.obj(params, obj)
  params            <- lgb.check.eval(params, eval)

  # A custom objective function is handed to booster$update() as `fobj`;
  # the `objective` parameter itself is replaced by the string "NONE".
  fobj  <- NULL
  feval <- NULL
  if (is.function(params$objective)) {
    fobj             <- params$objective
    params$objective <- "NONE"
  }
  if (is.function(eval)) {
    feval <- eval
  }
  lgb.check.params(params)

  # Optionally continue training from an existing model (file path or booster)
  predictor <- NULL
  if (is.character(init_model)) {
    predictor <- Predictor$new(init_model)
  } else if (lgb.is.Booster(init_model)) {
    predictor <- init_model$to_predictor()
  }
  begin_iteration <- 1
  if (!is.null(predictor)) {
    begin_iteration <- predictor$current_iter() + 1
  }
  end_iteration <- begin_iteration + nrounds - 1

  # Check dataset
  if (!lgb.is.Dataset(data)) {
    stop("lgb.train: data only accepts lgb.Dataset object")
  }
  if (length(valids) > 0) {
    if (!is.list(valids) ||
        !all(vapply(valids, lgb.is.Dataset, logical(1)))) {
      stop("lgb.train: valids must be a list of lgb.Dataset elements")
    }
    evnames <- names(valids)
    if (is.null(evnames) || !all(nzchar(evnames))) {
      stop("lgb.train: each element of the valids must have a name tag")
    }
  }

  data$update_params(params)
  data$.__enclos_env__$private$set_predictor(predictor)
  if (!is.null(colnames)) {
    data$set_colnames(colnames)
  }
  data$construct()

  # The training set itself may appear among `valids`. In that case it is
  # evaluated through booster$eval_train() under its `valids` name instead of
  # being added as a separate validation set.
  valid_contain_train <- FALSE
  train_data_name     <- "train"
  reduced_valid_sets  <- list()
  if (length(valids) > 0) {
    for (key in names(valids)) {
      valid_data <- valids[[key]]
      if (identical(data, valid_data)) {
        valid_contain_train <- TRUE
        train_data_name     <- key
        next
      }
      valid_data$update_params(params)
      valid_data$set_reference(data)
      reduced_valid_sets[[key]] <- valid_data
    }
  }

  # Process callbacks (scalar conditions, so use short-circuit `&&`)
  if (verbose > 0 && eval_freq > 0) {
    callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
  }
  if (record && length(valids) > 0) {
    callbacks <- add.cb(callbacks, cb.record.evaluation())
  }
  # Early stopping callback
  if (!is.null(early_stopping_rounds) && early_stopping_rounds > 0) {
    callbacks <- add.cb(
      callbacks,
      cb.early.stop(early_stopping_rounds, verbose = verbose)
    )
  }
  cb <- categorize.callbacks(callbacks)

  # Construct booster
  booster <- Booster$new(params = params, train_set = data)
  if (valid_contain_train) {
    booster$set_train_data_name(train_data_name)
  }
  for (key in names(reduced_valid_sets)) {
    booster$add_valid(reduced_valid_sets[[key]], key)
  }

  # Callback environment shared by all pre/post-iteration callbacks
  env                 <- CB_ENV$new()
  env$model           <- booster
  env$begin_iteration <- begin_iteration
  env$end_iteration   <- end_iteration

  # seq(from, to) counts *down* when to < from (e.g. seq(1, 0) is c(1, 0)),
  # which would train extra iterations for nrounds < 1; train nothing instead.
  iterations <- if (end_iteration >= begin_iteration) {
    seq(from = begin_iteration, to = end_iteration)
  } else {
    numeric(0)
  }

  # Start training
  for (i in iterations) {
    env$iteration <- i
    env$eval_list <- list()
    for (f in cb$pre_iter) {
      f(env)
    }

    # Update one iteration
    booster$update(fobj = fobj)

    # Collect evaluation results (training set first, if it was in `valids`)
    eval_list <- list()
    if (length(valids) > 0) {
      if (valid_contain_train) {
        eval_list <- append(eval_list, booster$eval_train(feval = feval))
      }
      eval_list <- append(eval_list, booster$eval_valid(feval = feval))
    }
    env$eval_list <- eval_list
    for (f in cb$post_iter) {
      f(env)
    }

    # Stop once a callback has flagged early stopping
    if (env$met_early_stop) break
  }

  booster
}