library(lightgbm)

# Load the agaricus datasets that ship with lightgbm.
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")

# Wrap the training data in an lgb.Dataset, and build a validation Dataset
# from it with lgb.Dataset.create.valid().
# NOTE(review): dtest is not used by the cross-validation calls in this
# script; it is kept for parity with the other demos.
dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
dtest <- lgb.Dataset.create.valid(
  dtrain
  , data = agaricus.test$data
  , label = agaricus.test$label
)

# Number of boosting rounds for every cross-validation run below.
nrounds <- 2L

# Booster parameters shared by all cross-validation runs below.
param <- list(
  num_leaves = 4L
  , learning_rate = 1.0
  , objective = "binary"
)

print("Running cross validation")
# Run 5-fold cross validation; results are printed per iteration as
#   [iteration]  metric_name:mean_value+std_value
# where std_value is the standard deviation of the metric.
lgb.cv(
  params = param
  , data = dtrain
  , nrounds = nrounds
  , nfold = 5L
  , eval = "binary_error"
)

print("Running cross validation, disable standard deviation display")
# Same 5-fold cross validation, but showsd = FALSE turns off the standard
# deviation display, so results are printed per iteration as
#   [iteration]  metric_name:mean_value
lgb.cv(
  params = param
  , data = dtrain
  , nrounds = nrounds
  , nfold = 5L
  , eval = "binary_error"
  , showsd = FALSE
)

# You can also do cross validation with a customized loss function
print("Running cross validation, with customized loss function")

# Custom binary logistic objective, passed to lgb.cv() as `obj`.
#
# preds:  raw (untransformed) scores, one per row of dtrain
# dtrain: the lgb.Dataset being fitted; its "label" field is read
#
# Returns a list with the first (grad) and second (hess) derivatives of the
# logistic loss with respect to the raw scores.
logregobj <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  # Map raw scores to probabilities with the logistic (sigmoid) function.
  probs <- stats::plogis(preds)
  grad <- probs - labels
  hess <- probs * (1 - probs)
  list(grad = grad, hess = hess)
}

# Custom evaluation metric, passed to lgb.cv() as `eval`: the fraction of
# rows whose predicted class (raw score > 0) differs from the label.
#
# Returns list(name, value, higher_better); higher_better = FALSE because a
# smaller error is better.
evalerror <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  # Assumes preds are raw scores (as logregobj treats them), so a score
  # above 0 corresponds to a predicted probability above 0.5.
  err <- mean(labels != (preds > 0))
  list(name = "error", value = err, higher_better = FALSE)
}

# Cross-validate with the customized objective and evaluation function.
# NOTE(review): `param` also sets objective = "binary"; confirm how the
# installed lightgbm version resolves that against `obj` so that logregobj
# is actually the objective being optimized.
lgb.cv(
  params = param
  , data = dtrain
  , nrounds = nrounds
  , obj = logregobj
  , eval = evalerror
  , nfold = 5L
)