lgb.train.R 6 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
#' Main training logic for LightGBM
#'
#' @param params List of parameters
#' @param data a \code{lgb.Dataset} object, used for training
#' @param nrounds number of training rounds; must be a single number greater than zero
#' @param valids a named list of \code{lgb.Dataset} objects, used for validation.
#'        Every element must have a unique, non-empty name. If one of the elements
#'        is \code{data} itself, it is evaluated as the training set under that name.
#' @param obj objective function, can be character or custom objective function
#' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param verbose verbosity for output
#'        if \code{verbose > 0} and \code{valids} is non-empty, also will record
#'        iteration message to \code{booster$record_evals}
#' @param eval_freq evaluation output frequency
#' @param init_model path of model file of \code{lgb.Booster} object, will continue training from this model
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature list of str or int
#'        type int represents index,
#'        type str represents feature names
#' @param early_stopping_rounds int
#'        Activates early stopping.
#'        Requires at least one validation data and one metric
#'        If there's more than one, will check all of them
#'        Returns the model with (best_iter + early_stopping_rounds)
#'        If early stopping occurs, the model will have 'best_iter' field
#' @param callbacks list of callback functions
#'        List of callback functions that are applied at each iteration.
#' @param ... other parameters, see parameters.md for more information
#' @return a trained booster model \code{lgb.Booster}.
#'
#' @examples
#' \dontrun{
#'   library(lightgbm)
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   data(agaricus.test, package='lightgbm')
#'   test <- agaricus.test
#'   dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#'   params <- list(objective="regression", metric="l2")
#'   valids <- list(test=dtest)
#'   model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' }
#'
#' @rdname lgb.train
#' @export
lgb.train <- function(params = list(), data, nrounds = 10,
                      valids                = list(),
                      obj                   = NULL,
                      eval                  = NULL,
                      verbose               = 1,
                      eval_freq             = 1L,
                      init_model            = NULL,
                      colnames              = NULL,
                      categorical_feature   = NULL,
                      early_stopping_rounds = NULL,
                      callbacks             = list(), ...) {
  # Validate nrounds first: with nrounds <= 0 the iteration range below would
  # run backwards (seq(1, 0) == c(1, 0)) and train spurious iterations.
  if (!is.numeric(nrounds) || length(nrounds) != 1L || is.na(nrounds) || nrounds <= 0) {
    stop("lgb.train: nrounds should be a single number greater than zero")
  }

  # Extra parameters passed through `...` are merged into `params`
  params         <- append(params, list(...))
  params$verbose <- verbose
  params         <- lgb.check.obj(params, obj)
  params         <- lgb.check.eval(params, eval)

  # A custom R objective is applied from R on each update; the native
  # library is told to use no built-in objective ("NONE") in that case.
  fobj  <- NULL
  feval <- NULL
  if (is.function(params$objective)) {
    fobj             <- params$objective
    params$objective <- "NONE"
  }
  if (is.function(eval)) {
    feval <- eval
  }

  lgb.check.params(params)

  # Optional warm start from a model file path or an existing booster
  predictor <- NULL
  if (is.character(init_model)) {
    predictor <- Predictor$new(init_model)
  } else if (lgb.is.Booster(init_model)) {
    predictor <- init_model$to_predictor()
  }

  # Continue iteration numbering after the iterations already in init_model
  begin_iteration <- 1
  if (!is.null(predictor)) {
    begin_iteration <- predictor$current_iter() + 1
  }
  end_iteration <- begin_iteration + nrounds - 1

  # check dataset
  if (!lgb.is.Dataset(data)) {
    stop("lgb.train: data only accepts lgb.Dataset object")
  }
  if (length(valids) > 0) {
    if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1)))) {
      stop("lgb.train: valids must be a list of lgb.Dataset elements")
    }
    evnames <- names(valids)
    if (is.null(evnames) || !all(nzchar(evnames))) {
      stop("lgb.train: each element of the valids must have a name tag")
    }
    # valids are looked up by name below; a duplicated name would silently
    # drop every dataset after the first one carrying that name.
    if (anyDuplicated(evnames) > 0) {
      stop("lgb.train: names of the valids must be unique")
    }
  }

  data$update_params(params)
  data$.__enclos_env__$private$set_predictor(predictor)
  if (!is.null(colnames)) {
    data$set_colnames(colnames)
  }
  data$set_categorical_feature(categorical_feature)
  data$construct()

  # If the training set itself appears in valids, it is evaluated as the
  # train set (under its valids name) rather than added as a validation set.
  valid_contains_train <- FALSE
  train_data_name      <- "train"
  reduced_valid_sets   <- list()
  for (key in names(valids)) {
    valid_data <- valids[[key]]
    if (identical(data, valid_data)) {
      valid_contains_train <- TRUE
      train_data_name      <- key
      next
    }
    valid_data$update_params(params)
    valid_data$set_reference(data)
    reduced_valid_sets[[key]] <- valid_data
  }

  # process callbacks
  if (eval_freq > 0) {
    callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
  }
  if (verbose > 0 && length(valids) > 0) {
    callbacks <- add.cb(callbacks, cb.record.evaluation())
  }

  # Early stopping callback
  if (!is.null(early_stopping_rounds) && early_stopping_rounds > 0) {
    callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose = verbose))
  }

  cb <- categorize.callbacks(callbacks)

  # construct booster
  booster <- Booster$new(params = params, train_set = data)
  if (valid_contains_train) {
    booster$set_train_data_name(train_data_name)
  }
  for (key in names(reduced_valid_sets)) {
    booster$add_valid(reduced_valid_sets[[key]], key)
  }

  # callback env
  env                 <- CB_ENV$new()
  env$model           <- booster
  env$begin_iteration <- begin_iteration
  env$end_iteration   <- end_iteration

  # start training; nrounds >= 1 guarantees an ascending range here
  for (i in seq(from = begin_iteration, to = end_iteration)) {
    env$iteration <- i
    env$eval_list <- list()
    for (f in cb$pre_iter) {
      f(env)
    }

    # update one iter
    booster$update(fobj = fobj)

    # collect eval result
    eval_list <- list()
    if (length(valids) > 0) {
      if (valid_contains_train) {
        eval_list <- append(eval_list, booster$eval_train(feval = feval))
      }
      eval_list <- append(eval_list, booster$eval_valid(feval = feval))
    }
    env$eval_list <- eval_list
    for (f in cb$post_iter) {
      f(env)
    }

    # met early stopping; isTRUE() tolerates the flag never having been set
    if (isTRUE(env$met_early_stop)) {
      break
    }
  }

  booster
}