context("Test models with custom objective")

data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")

# Training and held-out datasets built from the agaricus data bundled with
# the lightgbm package.
dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
dtest <- lgb.Dataset(agaricus.test$data, label = agaricus.test$label)

# Named list of datasets to evaluate during training; the list names
# ("eval", "train") label each dataset in the evaluation output.
watchlist <- list(eval = dtest, train = dtrain)

# Custom binary logistic-loss objective for lgb.train().
#
# Args:
#   preds:  raw (pre-sigmoid) scores, one per row of `dtrain`.
#   dtrain: the dataset being trained on; its "label" field supplies targets
#           (presumably 0/1 for this binary task).
# Returns:
#   A list with `grad` and `hess`: the first and second derivatives of the
#   log loss with respect to the raw scores.
logregobj <- function(preds, dtrain) {
  y <- getinfo(dtrain, "label")
  prob <- 1 / (1 + exp(-preds))
  list(
    grad = prob - y
    , hess = prob * (1 - prob)
  )
}

# Custom evaluation function: binary classification error rate.
#
# Args:
#   preds:  raw scores, one per row of `dtrain`. logregobj applies the sigmoid
#           itself, so these are untransformed margins and a score > 0
#           corresponds to a predicted probability > 0.5.
#   dtrain: the dataset being evaluated; its "label" field supplies targets.
# Returns:
#   A list with the metric `name`, its `value` (fraction of misclassified
#   rows), and `higher_better = FALSE` because a lower error is better.
evalerror <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  # The mean of a logical vector is the proportion of TRUE values.
  err <- mean(labels != (preds > 0))
  list(
    name = "error"
    , value = err
    , higher_better = FALSE
  )
}

# Training parameters. The custom objective function is passed directly as
# `objective`; `metric = "auc"` requests a built-in metric in addition to the
# custom evaluation function supplied to lgb.train() in the test below.
param <- list(
  num_leaves = 8L
  , learning_rate = 1.0
  , objective = logregobj
  , metric = "auc"
)

# Number of boosting iterations.
num_round <- 10L

# Trains with the custom objective and custom eval function, then checks that
# evaluation results were recorded, including the custom "error" metric on
# the "eval" (held-out) dataset.
test_that("custom objective works", {
  bst <- lgb.train(
    params = param
    , data = dtrain
    , nrounds = num_round
    , valids = watchlist
    , eval = evalerror
  )
  expect_false(is.null(bst$record_evals))
  expect_false(is.null(bst$record_evals[["eval"]][["error"]]))
})