# early_stopping.R
# Demo: training with early stopping, a custom objective and a custom metric
# Attach dependencies with library() so that a missing package stops the
# script immediately (require() only returns FALSE and lets it fail later)
library(lightgbm)
library(methods)

# Load in the agaricus dataset
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")

# Wrap the data and labels of each split in an lgb.Dataset
dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
dtest <- lgb.Dataset(agaricus.test$data, label = agaricus.test$label)

# Note: for customized objective function, we leave objective as default
# Note: what we are getting is margin value in prediction
# You must know what you are doing
param <- list(
  num_leaves = 4,
  learning_rate = 1
)

# Datasets evaluated during training, and the number of boosting rounds
valids <- list(eval = dtest)
num_round <- 20

# User-defined objective function: given the predictions, return the gradient
# and the second-order gradient (hessian) of the loss
# This is the log-likelihood (logistic) loss
#
# preds:  raw margin scores (before the logistic transformation)
# dtrain: dataset whose "label" field holds the labels
#         (presumably 0/1 -- confirm against the data)
# Returns a list with elements `grad` and `hess`, one value per prediction
logregobj <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  # Map the margins to probabilities with the logistic function
  probs <- 1 / (1 + exp(-preds))
  list(
    grad = probs - labels,
    hess = probs * (1 - probs)
  )
}

# User-defined evaluation function: returns a list with the metric name, its
# value, and whether a higher value is better
# NOTE: when you use a customized loss function, the default prediction value
# is the margin
# This may make the built-in evaluation metrics not function properly
# For example, with logistic loss the prediction is the score before the
# logistic transformation, while the built-in evaluation error assumes its
# input is after the logistic transformation
# Keep this in mind when you use the customization; you may need to write a
# customized evaluation function
#
# preds:  raw margin scores
# dtrain: dataset whose "label" field holds the labels
#         (presumably 0/1 -- confirm against the data)
# Returns list(name = "error", value = <misclassification rate>,
#              higher_better = FALSE)
evalerror <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  # preds are margins, so convert them to probabilities before applying the
  # 0.5 decision threshold (thresholding the raw margin at 0.5 is wrong)
  probs <- 1 / (1 + exp(-preds))
  err <- sum(labels != (probs > 0.5)) / length(labels)
  list(name = "error", value = err, higher_better = FALSE)
}
print("Start training with early Stopping setting")

# Train with the custom objective and custom metric, evaluating on `valids`.
# The early-stopping argument is spelled out in full rather than relying on
# partial argument matching.
bst <- lgb.train(
  params = param,
  data = dtrain,
  nrounds = num_round,
  valids = valids,
  objective = logregobj,
  eval = evalerror,
  early_stopping_rounds = 3
)