# Verbosity passed to LightGBM by every test in this file. It can be overridden
# through the LIGHTGBM_TEST_VERBOSITY environment variable and defaults to -1.
VERBOSITY <- as.integer(
    Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
)

context("lgb.interprete")

# Logistic (sigmoid) function: maps any real x into [0, 1]
# (exactly 0 / 1 at -Inf / Inf). Vectorized over x.
.sigmoid <- function(x) {
    1.0 / (1.0 + exp(-x))
}
# Logit (log-odds) function, the inverse of the logistic function: maps a
# probability x in (0, 1) onto the real line. Vectorized over x.
# Edge cases: returns -Inf at x = 0 and Inf at x = 1. Returns NaN (with a
# warning) for x outside [0, 1].
.logit <- function(x) {
    log(x / (1.0 - x))
}

test_that("lgb.interprete works as expected for binary classification", {
    # Load example data into this test's environment (not the global one)
    # so that tests do not leak state into each other.
    data(agaricus.train, package = "lightgbm", envir = environment())
    train <- agaricus.train
    dtrain <- lgb.Dataset(train$data, label = train$label)
    # Start boosting from the log-odds of the observed positive-label rate
    # rather than from a raw score of 0.
    set_field(
        dataset = dtrain
        , field_name = "init_score"
        , data = rep(
            .logit(mean(train$label))
            , length(train$label)
        )
    )
    data(agaricus.test, package = "lightgbm", envir = environment())
    test <- agaricus.test
    params <- list(
        objective = "binary"
        , learning_rate = 0.01
        , num_leaves = 63L
        , max_depth = -1L
        , min_data_in_leaf = 1L
        , min_sum_hessian_in_leaf = 1.0
        , verbose = VERBOSITY
    )
    model <- lgb.train(
        params = params
        , data = dtrain
        , nrounds = 3L
    )
    # Number of test rows to interpret: `idxset` selects rows of `data`,
    # so the result has one element per selected row.
    num_obs <- 5L
    tree_interpretation <- lgb.interprete(
        model = model
        , data = test$data
        , idxset = seq_len(num_obs)
    )
    expect_identical(class(tree_interpretation), "list")
    expect_length(tree_interpretation, num_obs)
    expect_null(names(tree_interpretation))
    # Each element must be a data.table with a character Feature column and
    # a numeric Contribution column. vapply() guarantees a logical vector.
    expect_true(all(
        vapply(
            X = tree_interpretation
            , FUN = function(treeDT) {
                checks <- c(
                    data.table::is.data.table(treeDT)
                    , identical(names(treeDT), c("Feature", "Contribution"))
                    , is.character(treeDT[, Feature])
                    , is.numeric(treeDT[, Contribution])
                )
                all(checks)
            }
            , FUN.VALUE = logical(1L)
        )
    ))
})

test_that("lgb.interprete works as expected for multiclass classification", {
    # Load iris into this test's environment (not the global one) so the
    # label recoding below cannot affect other tests.
    data(iris, envir = environment())

    # Multiclass labels must be numbers starting at 0 (0, 1, 2, ...),
    # so recode the Species factor levels to 0-based numeric labels.
    iris$Species <- as.numeric(as.factor(iris$Species)) - 1L

    # Create imbalanced training data (20, 30, 40 examples for classes 0, 1, 2)
    train <- as.matrix(iris[c(1L:20L, 51L:80L, 101L:140L), ])
    # The last 10 samples of each class are held out and used for interpretation
    test <- as.matrix(iris[c(41L:50L, 91L:100L, 141L:150L), ])
    dtrain <- lgb.Dataset(data = train[, 1L:4L], label = train[, 5L])
    params <- list(
        objective = "multiclass"
        , metric = "multi_logloss"
        , num_class = 3L
        , learning_rate = 0.00001
        , min_data = 1L
        , verbose = VERBOSITY
    )
    model <- lgb.train(
        params = params
        , data = dtrain
        , nrounds = 3L
    )
    # Number of test rows to interpret: `idxset` selects rows of `data`,
    # so the result has one element per selected row.
    num_obs <- 5L
    tree_interpretation <- lgb.interprete(
        model = model
        , data = test[, 1L:4L]
        , idxset = seq_len(num_obs)
    )
    expect_identical(class(tree_interpretation), "list")
    expect_length(tree_interpretation, num_obs)
    expect_null(names(tree_interpretation))
    # Each element must be a data.table with a character Feature column and
    # one numeric contribution column per class. vapply() guarantees a
    # logical vector.
    expect_true(all(
        vapply(
            X = tree_interpretation
            , FUN = function(treeDT) {
                checks <- c(
                    data.table::is.data.table(treeDT)
                    , identical(names(treeDT), c("Feature", "Class 0", "Class 1", "Class 2"))
                    , is.character(treeDT[, Feature])
                    , is.numeric(treeDT[, `Class 0`])
                    , is.numeric(treeDT[, `Class 1`])
                    , is.numeric(treeDT[, `Class 2`])
                )
                all(checks)
            }
            , FUN.VALUE = logical(1L)
        )
    ))
})