# cross_validation.R
# Use library() rather than require(): a missing package then stops the
# script immediately with a clear error instead of returning FALSE.
library(lightgbm)

# Load the agaricus train/test datasets bundled with lightgbm.
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")

# Wrap the feature matrices and labels as lgb.Dataset objects.
dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
# NOTE(review): dtest is not used by any of the runs in this script.
dtest <- lgb.Dataset(agaricus.test$data, label = agaricus.test$label)

# Settings shared by the cross-validation runs below: 2 boosting rounds,
# trees with at most 4 leaves, learning rate 1, binary objective.
nrounds <- 2
param <- list(num_leaves = 4,
              learning_rate = 1,
              objective = "binary")
print("Running cross validation")
# Do 5-fold cross validation; the result is printed per iteration as
#   [iteration]  metric_name:mean_value+std_value
# where std_value is the standard deviation of the metric across folds.
lgb.cv(params = param,
       data = dtrain,
       nrounds = nrounds,
       nfold = 5,
       eval = "binary_error")
print("Running cross validation, disable standard deviation display")
# Same cross validation, but showsd = FALSE turns off the standard
# deviation, so the result is printed per iteration as
#   [iteration]  metric_name:mean_value
lgb.cv(params = param,
       data = dtrain,
       nrounds = nrounds,
       nfold = 5,
       eval = "binary_error",
       showsd = FALSE)
# You can also do cross validation with a customized loss function.
print("Running cross validation, with customized loss function")
# Custom objective: binary logistic (log) loss.
#
# preds:  raw model scores; they are mapped to probabilities with the
#         sigmoid below, so they are expected on the log-odds scale.
# dtrain: dataset whose "label" field holds the targets (assumed 0/1).
# Returns a list with the per-row gradient (`grad`) and hessian (`hess`)
# of the log loss with respect to the raw scores.
logregobj <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  # Sigmoid: raw score -> probability.
  preds <- 1 / (1 + exp(-preds))
  grad <- preds - labels
  hess <- preds * (1 - preds)
  list(grad = grad, hess = hess)
}
# Custom evaluation metric: classification error rate.
#
# preds:  model scores; a score > 0 is counted as a positive prediction,
#         which assumes raw log-odds scores (0 <=> probability 0.5).
# dtrain: dataset whose "label" field holds the targets.
# Returns list(name, value, higher_better); lower error is better.
evalerror <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  err <- as.numeric(sum(labels != (preds > 0))) / length(labels)
  list(name = "error", value = err, higher_better = FALSE)
}

# Cross validation with the customized objective (logregobj) and
# evaluation metric (evalerror).
# `param` carries objective = "binary"; some lightgbm versions let a
# `params` entry take precedence over the `obj` argument, which would
# silently ignore logregobj. Put the custom objective into the params
# as well so it is used regardless of that precedence rule.
custom_param <- param
custom_param$objective <- logregobj
lgb.cv(params = custom_param,
       data = dtrain,
       nrounds = nrounds,
       obj = logregobj,
       eval = evalerror,
       nfold = 5)