"vscode:/vscode.git/clone" did not exist on "3837e60de6747a37917989998fdc8b4ad25af1d1"
test_utils.R 3.55 KB
Newer Older
1
2
3
4
5
6
7
8
9
context("lgb.encode.char")

test_that("lgb.encode.char throws an informative error if it is passed a non-raw input", {
    # A character string is not a raw vector, so encoding must fail with an
    # error message that names the expected input type.
    non_raw_input <- "some-string"
    expect_error(
        lgb.encode.char(non_raw_input)
        , regexp = "Can only encode from raw type"
    )
})

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
context("lgb.check.r6.class")

test_that("lgb.check.r6.class() should return FALSE for NULL input", {
    # NULL must never be reported as an instance of an lgb.* class
    null_is_dataset <- lgb.check.r6.class(NULL, "lgb.Dataset")
    expect_false(null_is_dataset)
})

test_that("lgb.check.r6.class() should return FALSE for non-R6 inputs", {
    # A plain integer whose class attribute merely claims "lgb.Dataset" is
    # not an R6 object, so the check must reject it.
    fake_dataset <- structure(5L, class = "lgb.Dataset")
    expect_false(lgb.check.r6.class(fake_dataset, "lgb.Dataset"))
})

test_that("lgb.check.r6.class() should correctly identify lgb.Dataset", {
    data("agaricus.train", package = "lightgbm")
    features <- agaricus.train[["data"]]
    labels <- agaricus.train[["label"]]
    dtrain <- lgb.Dataset(features, label = labels)
    # A Dataset matches its own class name and none of the other lgb.* classes
    expect_true(lgb.check.r6.class(dtrain, "lgb.Dataset"))
    expect_false(lgb.check.r6.class(dtrain, "lgb.Predictor"))
    expect_false(lgb.check.r6.class(dtrain, "lgb.Booster"))
})
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

context("lgb.params2str")

test_that("lgb.params2str() works as expected for empty lists", {
    # With no parameters, the result is a raw vector equal to lgb.c_str("")
    empty_str <- lgb.params2str(params = list())
    expect_identical(class(empty_str), "raw")
    expect_equal(empty_str, lgb.c_str(""))
})

test_that("lgb.params2str() works as expected for a key in params with multiple different-length elements", {
    # Expected output: key=value pairs separated by single spaces, elements of
    # a multi-valued key joined by commas with no padding between them, and
    # the small learning rate written in fixed (non-scientific) notation.
    expected_str <- "objective=magic metric=a,ab,abc,abcdefg nrounds=10 learning_rate=0.0000001"
    params <- list(
        objective = "magic"
        , metric = c("a", "ab", "abc", "abcdefg")
        , nrounds = 10L
        , learning_rate = 0.0000001
    )
    params_str <- lgb.params2str(params = params)
    expect_identical(class(params_str), "raw")
    expect_identical(rawToChar(params_str), expected_str)
})
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79

context("lgb.last_error")

test_that("lgb.last_error() throws an error if there are no errors", {
    # Even when nothing has failed, lgb.last_error() signals an error whose
    # message contains "Everything is fine".
    expect_error(
        lgb.last_error()
        , regexp = "Everything is fine"
    )
})

test_that("lgb.last_error() correctly returns errors from the C++ side", {
    # Quote the dataset name, matching the other data() calls in this file
    data("agaricus.train", package = "lightgbm")
    train <- agaricus.train
    # Supply a label vector of length 5 whose length does not match the number
    # of rows in the data, so constructing the Dataset fails in the C++ library.
    # Only the length matters. Fixed values (instead of rnorm()) keep the
    # test deterministic and avoid advancing the global RNG state
    # (.Random.seed), which would otherwise leak into later tests.
    dvalid1 <- lgb.Dataset(
        data = train$data
        , label = as.matrix(numeric(5L))
    )
    # fixed = TRUE is forwarded to the matcher so that "[" and "]" are literal
    expect_error({
        dvalid1$construct()
    }, regexp = "[LightGBM] [Fatal] Length of label is not same with #data", fixed = TRUE)
})
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117

context("lgb.check.eval")

test_that("lgb.check.eval works as expected with no metric", {
    # params has no "metric" entry, so one is added after the existing keys
    # and holds the eval name as a list
    input_params <- list(device = "cpu")
    checked <- lgb.check.eval(
        params = input_params
        , eval = "binary_error"
    )
    expect_named(checked, c("device", "metric"))
    expect_identical(checked[["metric"]], list("binary_error"))
})

test_that("lgb.check.eval adds eval to metric in params", {
    # The existing metric stays first; the eval name is appended after it
    expected_metrics <- list("auc", "binary_error")
    checked <- lgb.check.eval(
        params = list(metric = "auc")
        , eval = "binary_error"
    )
    expect_named(checked, "metric")
    expect_identical(checked[["metric"]], expected_metrics)
})

test_that("lgb.check.eval adds eval to metric in params if two evaluation names are provided", {
    # A character vector of eval names is appended element by element
    eval_names <- c("binary_error", "binary_logloss")
    checked <- lgb.check.eval(
        params = list(metric = "auc")
        , eval = eval_names
    )
    expect_named(checked, "metric")
    expected_metrics <- list("auc", "binary_error", "binary_logloss")
    expect_identical(checked[["metric"]], expected_metrics)
})

test_that("lgb.check.eval adds eval to metric in params if a list is provided", {
    # A list of eval names gives the same result as a character vector
    eval_names <- list("binary_error", "binary_logloss")
    checked <- lgb.check.eval(
        params = list(metric = "auc")
        , eval = eval_names
    )
    expect_named(checked, "metric")
    expected_metrics <- list("auc", "binary_error", "binary_logloss")
    expect_identical(checked[["metric"]], expected_metrics)
})