# We are going to look at how iterating too much might generate observation instability.
# Note that we are in a controlled environment here: the mushroom data follows real,
# deterministic rules, so overfitting them is not an issue. Do not do this in a real scenario.

# First, we load our libraries
library(lightgbm)
library(ggplot2)

# Second, we load our data
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)

# Third, we set up parameters and train a model
params <- list(objective = "regression", metric = "l2")
valids <- list(test = dtest)
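
# Note: we train with the regression objective on the 0/1 labels, so raw
# predictions are not guaranteed to stay inside [0, 1]; that is why they are
# clamped to [1e-15, 1 - 1e-15] later, before computing logloss.
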
model <- lgb.train(
    params
    , dtrain
    , nrounds = 50L
    , valids = valids
    , min_data = 1L
    , learning_rate = 0.1
    , bagging_fraction = 0.1
    , bagging_freq = 1L
    , bagging_seed = 1L
)
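
# Optionally (a side check, not needed for the rest of the demo), the fitted
# trees can be inspected with lgb.model.dt.tree(), which returns one row per
# node or leaf:
print(head(lgb.model.dt.tree(model)))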

# We create a data.frame with the following structure:
# X = average leaf index assigned to the observation across all trees
# Y = prediction probability (clamped to [1e-15, 1 - 1e-15])
# Z = logloss
# binned = quantile bin of the average leaf index
new_data <- data.frame(
    X = rowMeans(predict(
        model
        , agaricus.test$data
        , predleaf = TRUE
    ))
    , Y = pmin(
        pmax(
            predict(model, agaricus.test$data)
            , 1e-15
        )
        , 1.0 - 1e-15
    )
)
new_data$Z <- -1.0 * (agaricus.test$label * log(new_data$Y) + (1.0 - agaricus.test$label) * log(1.0 - new_data$Y))
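
# For reference, the per-observation binary logloss computed above is
#     -[y * log(p) + (1 - y) * log(1 - p)]
# A small illustrative helper (our own, not part of lightgbm) doing the same
# computation, clamping included:
logloss <- function(y, p) {
    p <- pmin(pmax(p, 1e-15), 1.0 - 1e-15)
    -1.0 * (y * log(p) + (1.0 - y) * log(1.0 - p))
}
stopifnot(isTRUE(all.equal(logloss(agaricus.test$label, new_data$Y), new_data$Z)))
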
new_data$binned <- .bincode(
    x = new_data$X
    , breaks = quantile(
        x = new_data$X
        , probs = seq_len(9L) / 10.0
    )
    , right = TRUE
    , include.lowest = TRUE
)
new_data$binned[is.na(new_data$binned)] <- 0L
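
# A tiny illustration of the edge case handled above (hypothetical values):
# both tails fall outside the quantile breaks, so .bincode() returns NA for
# them, and the line above maps those NAs to bin 0.
print(.bincode(
    x = c(0.05, 0.5, 0.95)
    , breaks = c(0.1, 0.5, 0.9)
    , right = TRUE
    , include.lowest = TRUE
))  # NA 1 NA
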
new_data$binned <- as.factor(new_data$binned)
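
# The construction above is repeated verbatim for model2 and model3 below. As
# an illustrative refactoring sketch (build_leaf_data is our own name, not
# part of lightgbm), the same steps could be factored out like this:
build_leaf_data <- function(model, data, label) {
    df <- data.frame(
        X = rowMeans(predict(model, data, predleaf = TRUE))
        , Y = pmin(pmax(predict(model, data), 1e-15), 1.0 - 1e-15)
    )
    df$Z <- -1.0 * (label * log(df$Y) + (1.0 - label) * log(1.0 - df$Y))
    df$binned <- .bincode(
        x = df$X
        , breaks = quantile(x = df$X, probs = seq_len(9L) / 10.0)
        , right = TRUE
        , include.lowest = TRUE
    )
    df$binned[is.na(df$binned)] <- 0L
    df$binned <- as.factor(df$binned)
    df
}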

# We can check the binned content
table(new_data$binned)

# We can plot the binned content
# On the second plot, we clearly see that the lower the bin (i.e. the lower the average leaf index), the higher the loss
# On the third plot, the bin density is smooth!
ggplot(
    data = new_data
    , mapping = aes(x = X, y = Y, color = binned)
) + geom_point() +
  theme_bw() +
  labs(title = "Prediction Depth", x = "Average Leaf Index", y = "Prediction Probability")
ggplot(
    data = new_data
    , mapping = aes(x = binned, y = Z, fill = binned, group = binned)
) + geom_boxplot() +
  theme_bw() +
  labs(title = "Prediction Depth Spread", x = "Leaf Bin", y = "Logloss")
ggplot(
    data = new_data
    , mapping = aes(x = Y, y = ..count.., fill = binned)
) + geom_density(position = "fill") +
  theme_bw() +
  labs(title = "Depth Density", x = "Prediction Probability", y = "Bin Density")


# Now, let's show the same behavior with other parameters
model2 <- lgb.train(
    params
    , dtrain
    , nrounds = 100L
    , valids = valids
    , min_data = 1L
    , learning_rate = 1.0
)

# We create the data structure, but for model2
new_data2 <- data.frame(
    X = rowMeans(predict(
        model2
        , agaricus.test$data
        , predleaf = TRUE
    ))
    , Y = pmin(
        pmax(
            predict(
                model2
                , agaricus.test$data
            )
            , 1e-15
        )
        , 1.0 - 1e-15
    )
)
new_data2$Z <- -1.0 * (agaricus.test$label * log(new_data2$Y) + (1.0 - agaricus.test$label) * log(1.0 - new_data2$Y))
new_data2$binned <- .bincode(
    x = new_data2$X
    , breaks = quantile(
        x = new_data2$X
        , probs = seq_len(9L) / 10.0
    )
    , right = TRUE
    , include.lowest = TRUE
)
new_data2$binned[is.na(new_data2$binned)] <- 0L
new_data2$binned <- as.factor(new_data2$binned)
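
# (With the illustrative helper defined earlier, this is equivalent to:
#  new_data2 <- build_leaf_data(model2, agaricus.test$data, agaricus.test$label))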

# We can check the binned content
table(new_data2$binned)

# We can plot the binned content
# On the second plot, we clearly see that the lower the bin (i.e. the lower the average leaf index), the higher the loss
# On the third plot, the density is clearly not smooth! We are severely overfitting the data, but the
# rules are real, thus it is not an issue
# However, if the rules were not true, the loss would explode.
ggplot(
    data = new_data2
    , mapping = aes(x = X, y = Y, color = binned)
) + geom_point() +
  theme_bw() +
  labs(title = "Prediction Depth", x = "Average Leaf Index", y = "Prediction Probability")
ggplot(
    data = new_data2
    , mapping = aes(x = binned, y = Z, fill = binned, group = binned)
) + geom_boxplot() +
  theme_bw() +
  labs(title = "Prediction Depth Spread", x = "Leaf Bin", y = "Logloss")
ggplot(
    data = new_data2
    , mapping = aes(x = Y, y = ..count.., fill = binned)
) + geom_density(position = "fill") +
  theme_bw() +
  labs(title = "Depth Density", x = "Prediction Probability", y = "Bin Density")


# Now, try with very severe overfitting
model3 <- lgb.train(
    params
    , dtrain
    , nrounds = 1000L
    , valids = valids
    , min_data = 1L
    , learning_rate = 1.0
)
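
# In a real scenario, a standard guard against over-iterating is early
# stopping on the validation set. A minimal sketch (model_es is our own
# name; it is not used further in this demo):
model_es <- lgb.train(
    params
    , dtrain
    , nrounds = 1000L
    , valids = valids
    , min_data = 1L
    , learning_rate = 1.0
    , early_stopping_rounds = 10L
)
print(model_es$best_iter)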

# We create the data structure, but for model3
new_data3 <- data.frame(
    X = rowMeans(predict(
        model3
        , agaricus.test$data
        , predleaf = TRUE
    ))
    , Y = pmin(
        pmax(
            predict(
                model3
                , agaricus.test$data
            )
            , 1e-15
        )
        , 1.0 - 1e-15
    )
)
new_data3$Z <- -1.0 * (agaricus.test$label * log(new_data3$Y) + (1.0 - agaricus.test$label) * log(1.0 - new_data3$Y))
new_data3$binned <- .bincode(
    x = new_data3$X
    , breaks = quantile(
        x = new_data3$X
        , probs = seq_len(9L) / 10.0
    )
    , right = TRUE
    , include.lowest = TRUE
)
new_data3$binned[is.na(new_data3$binned)] <- 0L
new_data3$binned <- as.factor(new_data3$binned)

# We can check the binned content
table(new_data3$binned)

# We can plot the binned content
# On the third plot, the density is clearly not smooth! We are severely overfitting the data, but the
# rules are real, thus it is not an issue.
# However, if the rules were not true, the loss would explode. See the sudden spikes?
ggplot(
    data = new_data3
    , mapping = aes(x = Y, y = ..count.., fill = binned)
) + geom_density(position = "fill") +
  theme_bw() +
  labs(title = "Depth Density", x = "Prediction Probability", y = "Bin Density")

# Compare with our second model: the difference is severe. This one (model2) is smooth.
ggplot(
    data = new_data2
    , mapping = aes(x = Y, y = ..count.., fill = binned)
) + geom_density(position = "fill") +
  theme_bw() +
  labs(title = "Depth Density", x = "Prediction Probability", y = "Bin Density")