# test_basic.R: basic training, prediction and cross-validation tests for the lightgbm R package
# Group every expectation in this file under a single reporter heading.
context("basic functions")

# Load the example datasets from the lightgbm package. The tests below read
# the `$data` and `$label` components of these objects.
data(agaricus.train, package = "lightgbm")
train <- agaricus.train

data(agaricus.test, package = "lightgbm")
test <- agaricus.test

# TRUE when the operating system name reported by Sys.info() contains "Windows".
windows_flag <- grepl("Windows", Sys.info()[["sysname"]])

test_that("train and predict binary classification", {
  # Fit a small binary classifier, tracking the training error per iteration.
  n_iter <- 10
  booster <- lightgbm(
    data = train$data
    , label = train$label
    , num_leaves = 5
    , nrounds = n_iter
    , objective = "binary"
    , metric = "binary_error"
  )

  # Evaluation results must have been recorded, and the smallest recorded
  # training error must be below 0.02.
  expect_false(is.null(booster$record_evals))
  train_errors <- lgb.get.eval.result(booster, "train", "binary_error")
  expect_lt(min(train_errors), 0.02)

  # The prediction vector for the test data is expected to hold 1611 values.
  test_preds <- predict(booster, test$data)
  expect_equal(length(test_preds), 1611)

  # Predictions restricted to the first iteration: 6513 values expected for
  # the training data.
  first_iter_preds <- predict(booster, train$data, num_iteration = 1)
  expect_equal(length(first_iter_preds), 6513)

  # The misclassification rate at a 0.5 threshold should agree with the first
  # entry of the recorded training error.
  n_misclassified <- sum((first_iter_preds > 0.5) != train$label)
  first_iter_error <- n_misclassified / length(train$label)
  logged_error <- train_errors[1]
  expect_lt(abs(first_iter_error - logged_error), 10e-6)
})


test_that("train and predict softmax", {
  # Use the four measurement columns of iris as features and the Species
  # factor codes, shifted to start at 0, as class labels.
  features <- as.matrix(iris[, -5])
  class_ids <- as.numeric(iris$Species) - 1
  n_classes <- 3

  booster <- lightgbm(
    data = features
    , label = class_ids
    , num_leaves = 4
    , learning_rate = 0.1
    , nrounds = 20
    , min_data = 20
    , min_hess = 20
    , objective = "multiclass"
    , metric = "multi_error"
    , num_class = n_classes
  )

  # Evaluation results must have been recorded, and the smallest recorded
  # training error must be below 0.03.
  expect_false(is.null(booster$record_evals))
  train_errors <- lgb.get.eval.result(booster, "train", "multi_error")
  expect_lt(min(train_errors), 0.03)

  # Expect one predicted value per (row, class) pair.
  class_preds <- predict(booster, features)
  expect_equal(length(class_preds), nrow(iris) * n_classes)
})


test_that("use of multiple eval metrics works", {
  # Request three evaluation metrics at once.
  eval_metrics <- list("binary_error", "auc", "binary_logloss")

  booster <- lightgbm(
    data = train$data
    , label = train$label
    , num_leaves = 4
    , learning_rate = 1
    , nrounds = 10
    , objective = "binary"
    , metric = eval_metrics
  )

  # Training with several metrics must still record evaluation results.
  expect_false(is.null(booster$record_evals))
})


test_that("training continuation works", {
  # Everything below is currently skipped; the body is kept so the test can be
  # re-enabled by removing this line.
  testthat::skip("This test is currently broken. See issue #2468 for details.")

  # free_raw_data = FALSE because the same Dataset is passed to several
  # lgb.train() calls below.
  dtrain <- lgb.Dataset(
    train$data
    , label = train$label
    , free_raw_data = FALSE
  )
  watchlist <- list(train = dtrain)
  param <- list(
    objective = "binary"
    , metric = "binary_logloss"
    , num_leaves = 5
    , learning_rate = 1
  )

  # Reference: 10 iterations trained in one go.
  bst <- lgb.train(param, dtrain, nrounds = 10, watchlist)
  err_bst <- lgb.get.eval.result(bst, "train", "binary_logloss", 10)

  # First 5 iterations.
  bst1 <- lgb.train(param, dtrain, nrounds = 5, watchlist)

  # Save the 5-iteration model to a temporary file rather than a fixed path in
  # the working directory, so the test leaves nothing behind and cannot
  # collide with an existing "lightgbm.model" file.
  model_file <- tempfile(fileext = ".model")
  lgb.save(bst1, model_file)

  # Continue for 5 more iterations from the in-memory booster.
  bst2 <- lgb.train(param, dtrain, nrounds = 5, watchlist, init_model = bst1)
  err_bst2 <- lgb.get.eval.result(bst2, "train", "binary_logloss", 10)
  expect_lt(abs(err_bst - err_bst2), 0.01)

  # Continue for 5 more iterations from the model saved on disk.
  bst_from_file <- lgb.train(
    param
    , dtrain
    , nrounds = 5
    , watchlist
    , init_model = model_file
  )
  err_from_file <- lgb.get.eval.result(
    bst_from_file
    , "train"
    , "binary_logloss"
    , 10
  )
  expect_lt(abs(err_bst - err_from_file), 0.01)

  # Remove the temporary model file once it is no longer needed.
  unlink(model_file)
})

# Cross-validation ----

test_that("cv works", {
  # Regression objective evaluated with two metrics, passed as one
  # comma-separated string.
  cv_params <- list(objective = "regression", metric = "l2,l1")
  dtrain <- lgb.Dataset(train$data, label = train$label)

  cv_result <- lgb.cv(
    params = cv_params
    , data = dtrain
    , nrounds = 10
    , nfold = 5
    , min_data = 1
    , learning_rate = 1
    , early_stopping_rounds = 10
  )

  # Cross-validation must record evaluation results.
  expect_false(is.null(cv_result$record_evals))
})