# Tests for lgb.plot.interpretation() (test_lgb.plot.interpretation.R)

# Logistic (sigmoid) function: maps a log-odds value `x` to a probability
# in (0, 1). Inverse of `.logit()`. Vectorized over `x`.
.sigmoid <- function(x) {
    1.0 / (1.0 + exp(-x))
}
# Logit (log-odds) function: maps a probability `x` to log(x / (1 - x)).
# Inverse of `.sigmoid()`. Vectorized over `x`; returns -Inf / Inf at
# x = 0 / x = 1.
.logit <- function(x) {
    log(x / (1.0 - x))
}

# Trains a small binary model (boosting from a constant init_score), interprets
# a few test rows with lgb.interprete(), and checks that
# lgb.plot.interpretation() can plot the result, both with the default and an
# explicit `cex`.
test_that("lgb.plot.interpretation works as expected for binary classification", {
    data(agaricus.train, package = "lightgbm")
    train <- agaricus.train
    dtrain <- lgb.Dataset(train$data, label = train$label)
    # Start boosting from the log-odds of the overall positive-label rate
    set_field(
        dataset = dtrain
        , field_name = "init_score"
        , data = rep(
            .logit(mean(train$label))
            , length(train$label)
        )
    )
    data(agaricus.test, package = "lightgbm")
    test <- agaricus.test
    params <- list(
        objective = "binary"
        , learning_rate = 0.01
        , num_leaves = 63L
        , max_depth = -1L
        , min_data_in_leaf = 1L
        , min_sum_hessian_in_leaf = 1.0
        , verbosity = .LGB_VERBOSITY
        , num_threads = .LGB_MAX_THREADS
    )
    model <- lgb.train(
        params = params
        , data = dtrain
        , nrounds = 3L
    )
    # `idxset` selects which rows of `data` to interpret (it does not count trees)
    num_rows_to_interpret <- 5L
    tree_interpretation <- lgb.interprete(
        model = model
        , data = test$data
        , idxset = seq_len(num_rows_to_interpret)
    )
    # Passes as long as plotting completes without raising an error
    expect_true({
        lgb.plot.interpretation(
            tree_interpretation_dt = tree_interpretation[[1L]]
            , top_n = 5L
        )
        TRUE
    })

    # should also work when you explicitly pass cex
    plot_res <- lgb.plot.interpretation(
        tree_interpretation_dt = tree_interpretation[[1L]]
        , top_n = 5L
        , cex = 0.95
    )
    expect_null(plot_res)
})

# Trains a small 3-class model on an imbalanced subset of iris, interprets a few
# held-out rows with lgb.interprete(), and checks that lgb.plot.interpretation()
# can plot the result.
test_that("lgb.plot.interpretation works as expected for multiclass classification", {
    data(iris)

    # We must convert factors to numeric
    # They must be starting from number 0 to use multiclass
    # For instance: 0, 1, 2, 3, 4, 5...
    iris$Species <- as.numeric(as.factor(iris$Species)) - 1L

    # Create imbalanced training data (20, 30, 40 examples for classes 0, 1, 2)
    train <- as.matrix(iris[c(1L:20L, 51L:80L, 101L:140L), ])
    # The 10 last samples of each class are held out and used for interpretation
    test <- as.matrix(iris[c(41L:50L, 91L:100L, 141L:150L), ])
    dtrain <- lgb.Dataset(data = train[, 1L:4L], label = train[, 5L])
    params <- list(
        objective = "multiclass"
        , metric = "multi_logloss"
        , num_class = 3L
        , learning_rate = 0.00001
        , min_data = 1L
        , num_threads = .LGB_MAX_THREADS
    )
    model <- lgb.train(
        params = params
        , data = dtrain
        , nrounds = 3L
        , verbose = .LGB_VERBOSITY
    )
    # `idxset` selects which rows of `data` to interpret (it does not count trees)
    num_rows_to_interpret <- 5L
    tree_interpretation <- lgb.interprete(
        model = model
        , data = test[, 1L:4L]
        , idxset = seq_len(num_rows_to_interpret)
    )
    plot_res <- lgb.plot.interpretation(
        tree_interpretation_dt = tree_interpretation[[1L]]
        , top_n = 5L
    )
    expect_null(plot_res)
})