# multiclass_custom_objective.R
library(lightgbm)
# We load the default iris dataset shipped with R
data(iris)

# We must convert factors to numeric
# They must be starting from number 0 to use multiclass
# For instance: 0, 1, 2, 3, 4, 5...
iris$Species <- as.numeric(as.factor(iris$Species)) - 1L

# Create imbalanced training data (20, 30, 40 examples for classes 0, 1, 2)
train <- as.matrix(iris[c(1L:20L, 51L:80L, 101L:140L), ])
# The 10 last samples of each class are for validation
test <- as.matrix(iris[c(41L:50L, 91L:100L, 141L:150L), ])

# Wrap the raw matrices as lightgbm Datasets: the first 4 columns are
# features, the 5th column is the 0-based class label
dtrain <- lgb.Dataset(data = train[, 1L:4L], label = train[, 5L])
dtest <- lgb.Dataset.create.valid(dtrain, data = test[, 1L:4L], label = test[, 5L])
valids <- list(train = dtrain, test = dtest)

# Method 1 of training with built-in multiclass objective
# Note: need to turn off boost from average to match custom objective
# (https://github.com/microsoft/LightGBM/issues/1846)
params <- list(
    min_data = 1L
    , learning_rate = 1.0
    , num_class = 3L
    , boost_from_average = FALSE
    , metric = "multi_logloss"
)
# Pass every argument by name: positional matching against lgb.train() is
# fragile because the order of its formals has changed across lightgbm releases
model_builtin <- lgb.train(
    params = params
    , data = dtrain
    , nrounds = 100L
    , valids = valids
    , early_stopping_rounds = 10L
    , obj = "multiclass"
)

# For multiclass models predict() returns a flat vector unless reshape = TRUE;
# a (samples x classes) matrix is required for the row-wise softmax below
preds_builtin <- predict(model_builtin, test[, 1L:4L], rawscore = TRUE, reshape = TRUE)
probs_builtin <- exp(preds_builtin) / rowSums(exp(preds_builtin))

# Method 2 of training with custom objective function

# User-defined objective: given raw (un-softmaxed) predictions and the
# training Dataset, return the gradient and an approximate second-order
# gradient of the multiclass log-loss with respect to the raw scores
custom_multiclass_obj <- function(preds, dtrain) {
    labels <- get_field(dtrain, "label")

    # preds is a matrix with rows corresponding to samples and columns corresponding to choices
    preds <- matrix(preds, nrow = length(labels))

    # to prevent overflow, normalize preds by row
    preds <- preds - apply(preds, MARGIN = 1L, max)
    prob <- exp(preds) / rowSums(exp(preds))

    # gradient of softmax cross-entropy: the class probability, minus 1 in
    # the column of the true class for each row
    grad <- prob
    # (row, column) index pairs selecting each sample's true-class entry;
    # labels are 0-based so the column is labels + 1
    subset_index <- cbind(seq_along(labels), labels + 1L)
    grad[subset_index] <- grad[subset_index] - 1.0

    # compute hessian (approximation)
    hess <- 2.0 * prob * (1.0 - prob)

    return(list(grad = grad, hess = hess))
}

# define custom metric: multiclass log-loss, to mirror "multi_logloss"
custom_multiclass_metric <- function(preds, dtrain) {
    labels <- get_field(dtrain, "label")

    # reshape the flat prediction vector to (samples x classes) and apply a
    # numerically-stable row-wise softmax (shift each row by its max)
    preds <- matrix(preds, nrow = length(labels))
    preds <- preds - apply(preds, 1L, max)
    prob <- exp(preds) / rowSums(exp(preds))

    # (row, column) index pairs selecting each sample's true-class entry;
    # labels are 0-based so the column is labels + 1
    subset_index <- cbind(seq_along(labels), labels + 1L)

    # mean negative log-probability of the true class; lower is better
    return(list(
        name = "error"
        , value = -mean(log(prob[subset_index]))
        , higher_better = FALSE
    ))
}

# Train again with the custom objective and metric defined above; the result
# should match the built-in "multiclass" objective exactly
params <- list(
    min_data = 1L
    , learning_rate = 1.0
    , num_class = 3L
)
# Pass every argument by name: positional matching against lgb.train() is
# fragile because the order of its formals has changed across lightgbm releases
model_custom <- lgb.train(
    params = params
    , data = dtrain
    , nrounds = 100L
    , valids = valids
    , early_stopping_rounds = 10L
    , obj = custom_multiclass_obj
    , eval = custom_multiclass_metric
)

# For multiclass models predict() returns a flat vector unless reshape = TRUE;
# a (samples x classes) matrix is required for the row-wise softmax below
preds_custom <- predict(model_custom, test[, 1L:4L], rawscore = TRUE, reshape = TRUE)
probs_custom <- exp(preds_custom) / rowSums(exp(preds_custom))

# compare predictions: the custom objective/metric must reproduce the
# built-in multiclass objective bit-for-bit
stopifnot(
    identical(probs_builtin, probs_custom)
    , identical(preds_builtin, preds_custom)
)