# lgb.train.R
#' Main training logic for LightGBM
#'
#' @param params List of parameters
#' @param data a \code{lgb.Dataset} object, used for training
#' @param nrounds number of training rounds; must be a single non-negative
#'        number (\code{0} runs no boosting iterations)
#' @param valids a list of \code{lgb.Dataset} objects, used for validation.
#'        Every element must be named. An element that is the training
#'        dataset itself is evaluated as training data under its name.
#' @param obj objective function, can be character or custom objective function
#' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param verbose verbosity for output, if <= 0, also will disable the print of evaluation during training
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#'        (only when \code{valids} is non-empty)
#' @param eval_freq evaluation output frequency, only takes effect when verbose > 0
#' @param init_model path of model file or \code{lgb.Booster} object, will continue training from this model
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature list of str or int
#'        type int represents index,
#'        type str represents feature names
#' @param early_stopping_rounds int
#'        Activates early stopping.
#'        Requires at least one validation data and one metric
#'        If there's more than one, will check all of them
#'        Returns the model with (best_iter + early_stopping_rounds)
#'        If early stopping occurs, the model will have 'best_iter' field
#' @param callbacks list of callback functions
#'        List of callback functions that are applied at each iteration.
#' @param ... other parameters, see parameters.md for more information
#' @return a trained booster model \code{lgb.Booster}.
#' @examples
#' \dontrun{
#'   library(lightgbm)
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   data(agaricus.test, package='lightgbm')
#'   test <- agaricus.test
#'   dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#'   params <- list(objective="regression", metric="l2")
#'   valids <- list(test=dtest)
#'   model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' }
#' @rdname lgb.train
#' @export
lgb.train <- function(params = list(), data, nrounds = 10,
                      valids                = list(),
                      obj                   = NULL,
                      eval                  = NULL,
                      verbose               = 1,
                      record                = TRUE,
                      eval_freq             = 1L,
                      init_model            = NULL,
                      colnames              = NULL,
                      categorical_feature   = NULL,
                      early_stopping_rounds = NULL,
                      callbacks             = list(), ...) {
  # Validate nrounds first. A negative/NA value previously made
  # seq(begin, end) count *down* and silently run extra iterations.
  if (!is.numeric(nrounds) || length(nrounds) != 1L ||
      is.na(nrounds) || nrounds < 0) {
    stop("lgb.train: nrounds should be a non-negative number")
  }

  # Merge extra arguments into params; the explicit verbose argument wins.
  params         <- append(params, list(...))
  params$verbose <- verbose
  params         <- lgb.check.obj(params, obj)
  params         <- lgb.check.eval(params, eval)

  # A custom (R function) objective is kept aside and passed to
  # booster$update() every iteration; params then records objective "NONE".
  fobj  <- NULL
  feval <- NULL
  if (is.function(params$objective)) {
    fobj             <- params$objective
    params$objective <- "NONE"
  }
  if (is.function(eval)) {
    feval <- eval
  }
  lgb.check.params(params)

  # Optionally continue training from an existing model (path or Booster);
  # iteration numbering then resumes after the model's current iteration.
  predictor <- NULL
  if (is.character(init_model)) {
    predictor <- Predictor$new(init_model)
  } else if (lgb.is.Booster(init_model)) {
    predictor <- init_model$to_predictor()
  }
  begin_iteration <- 1
  if (!is.null(predictor)) {
    begin_iteration <- predictor$current_iter() + 1
  }
  end_iteration <- begin_iteration + nrounds - 1

  # check dataset
  if (!lgb.is.Dataset(data)) {
    stop("lgb.train: data only accepts lgb.Dataset object")
  }
  if (length(valids) > 0) {
    if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1)))) {
      stop("lgb.train: valids must be a list of lgb.Dataset elements")
    }
    evnames <- names(valids)
    if (is.null(evnames) || !all(nzchar(evnames))) {
      stop("lgb.train: each element of the valids must have a name tag")
    }
  }

  # Prepare the training dataset. set_predictor is a private R6 method,
  # hence the reach through .__enclos_env__.
  data$update_params(params)
  data$.__enclos_env__$private$set_predictor(predictor)
  if (!is.null(colnames)) {
    data$set_colnames(colnames)
  }
  data$set_categorical_feature(categorical_feature)
  data$construct()

  # Split valids into the training set itself (evaluated via eval_train under
  # the user's name) and genuine validation sets (referencing `data`).
  valid_contains_train <- FALSE
  train_data_name      <- "train"
  reduced_valid_sets   <- list()
  if (length(valids) > 0) {
    for (key in names(valids)) {
      valid_data <- valids[[key]]
      if (identical(data, valid_data)) {
        valid_contains_train <- TRUE
        train_data_name      <- key
        next
      }
      valid_data$update_params(params)
      valid_data$set_reference(data)
      reduced_valid_sets[[key]] <- valid_data
    }
  }

  # process callbacks
  if (verbose > 0 && eval_freq > 0) {
    callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
  }
  if (record && length(valids) > 0) {
    callbacks <- add.cb(callbacks, cb.record.evaluation())
  }
  # Early stopping callback (disabled when NULL or <= 0)
  if (!is.null(early_stopping_rounds) && early_stopping_rounds > 0) {
    callbacks <- add.cb(
      callbacks,
      cb.early.stop(early_stopping_rounds, verbose = verbose)
    )
  }
  cb <- categorize.callbacks(callbacks)

  # construct booster
  booster <- Booster$new(params = params, train_set = data)
  if (valid_contains_train) {
    booster$set_train_data_name(train_data_name)
  }
  for (key in names(reduced_valid_sets)) {
    booster$add_valid(reduced_valid_sets[[key]], key)
  }

  # callback env
  env                 <- CB_ENV$new()
  env$model           <- booster
  env$begin_iteration <- begin_iteration
  env$end_iteration   <- end_iteration

  # seq(from, to) counts down when to < from, so only build the iteration
  # range when at least one round was requested.
  iterations <- if (nrounds >= 1) {
    seq(from = begin_iteration, to = end_iteration)
  } else {
    numeric(0)
  }

  # start training
  for (i in iterations) {
    env$iteration <- i
    env$eval_list <- list()
    for (f in cb$pre_iter) {
      f(env)
    }

    # update one iteration
    booster$update(fobj = fobj)

    # collect eval results; the training set is only evaluated when it was
    # also supplied in valids
    eval_list <- list()
    if (length(valids) > 0) {
      if (valid_contains_train) {
        eval_list <- append(eval_list, booster$eval_train(feval = feval))
      }
      eval_list <- append(eval_list, booster$eval_valid(feval = feval))
    }
    env$eval_list <- eval_list
    for (f in cb$post_iter) {
      f(env)
    }

    # met early stopping (flag presumably set by a callback such as
    # cb.early.stop); isTRUE() tolerates an unset/NULL flag
    if (isTRUE(env$met_early_stop)) break
  }

  booster
}